使用Pytorch实现two-head(多输出)模型的操作


Posted in Python onMay 28, 2021

如何使用Pytorch实现two-head(多输出)模型

1. two-head模型定义

先放一张我要实现的模型结构图:

使用Pytorch实现two-head(多输出)模型的操作

如上图,就是一个two-head模型,也是一个但输入多输出模型。该模型的特点是输入一个x和一个t,h0和h1中只有一个会输出,所以可能这不算是一个典型的多输出模型。

2.实现所遇到的困难 一开始的想法:

这不是很简单嘛,做一个判断不就完了,t=0时模型为前半段加h0,t=1时模型为前半段加h1。但实现的时候傻眼了,发现在真正前向传播的时候t是一个tensor,有0有1,没法儿进行判断。

灵机一动,又生一法:把这个模型变为三个模型,前半段是一个模型(r),后面的h0和h1分别为另两个模型。把数据集按t=0和1分开,分别训练两个模型:r+h0和r+h1。

但是后来搜如何进行模型串联,发现极为麻烦。

3.解决方案

后来在pytorch的官方社区中看到一个极为简单的方法:

(1) 按照一般的多输出模型进行实现,代码如下:

def forward(self, x):
        #三层的表示层
        x = F.elu(self.fcR1(x))
        x = F.elu(self.fcR2(x))
        x = F.elu(self.fcR3(x))
		#two-head,两个head分别进行输出
        y0 = F.elu(self.fcH01(x))
        y0 = F.elu(self.fcH02(y0))
        y0 = F.elu(self.fcH03(y0))
        y1 = F.elu(self.fcH11(x))
        y1 = F.elu(self.fcH12(y1))
        y1 = F.elu(self.fcH13(y1))
        return y0, y1

这样就相当实现了一个多输出模型,一个x同时输出y0和y1.

训练的时候分别训练,也即分别建loss,代码如下:

f_out_y0, _ = net(x0)
            _, f_out_y1 = net(x1)
            #实例化损失函数
            criterion0 = Loss()
            criterion1 = Loss()
            loss0 = criterion0(f_y0, f_out_y0, w0)
            loss1 = criterion1(f_y1, f_out_y1, w1)
            print(loss0.item(), loss1.item())
            #对网络参数进行初始化
            optimizer.zero_grad()
            loss0.backward()
            loss1.backward()
            #对网络的参数进行更新
            optimizer.step()

先把x按t=0和t=1分为x0和x1,然后分别送入进行训练。这样就实现了一个two-head模型。

4.后记

我自以为多输出模型可以分为以下两类:

多个输出不同时获得,如本文情况。

多个输出同时获得。

多输出不同时获得的解决方法上文已说明。多输出同时获得则可以通过把y0和y1拼接起来一起输出来实现。

补充:PyTorch 多输入多输出模型构建

本篇教程基于 PyTorch 1.5版本

直接上代码!

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.distributed as dist
import torch.utils.data as data_utils
class Net(nn.Module):
    def __init__(self, n_input, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden1 = nn.Linear(n_input, n_hidden)
        self.hidden2 = nn.Linear(n_hidden, n_hidden)
        self.predict1 = nn.Linear(n_hidden*2, n_output)
        self.predict2 = nn.Linear(n_hidden*2, n_output)
    def forward(self, input1, input2): # 多输入!!!
        out01 = self.hidden1(input1)
        out02 = torch.relu(out01)
        out03 = self.hidden2(out02)
        out04 = torch.sigmoid(out03)
        out11 = self.hidden1(input2)
        out12 = torch.relu(out11)
        out13 = self.hidden2(out12)
        out14 = torch.sigmoid(out13)
        out = torch.cat((out04, out14), dim=1) # 模型层拼合!!!当然你的模型中可能不需要~
 
        out1 = self.predict1(out)
        out2 = self.predict2(out)
        return out1, out2 # 多输出!!!
net = Net(1, 20, 1)
x1 = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # 请不要关心这里,随便弄一个数据,为了说明问题而已
y1 = x1.pow(3)+0.1*torch.randn(x1.size())
x2 = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y2 = x2.pow(3)+0.1*torch.randn(x2.size())
x1, y1 = (Variable(x1), Variable(y1))
x2, y2 = (Variable(x2), Variable(y2))
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
loss_func = torch.nn.MSELoss()
for t in range(5000):
    prediction1, prediction2 = net(x1, x2)
    loss1 = loss_func(prediction1, y1)
    loss2 = loss_func(prediction2, y2)
    loss = loss1 + loss2 # 重点!
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if t % 100 == 0:
       print('Loss1 = %.4f' % loss1.data,'Loss2 = %.4f' % loss2.data,)

至此搞定!

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python原始套接字编程示例分享
Feb 21 Python
Python使用迭代器捕获Generator返回值的方法
Apr 05 Python
Python字典数据对象拆分的简单实现方法
Dec 05 Python
Python中property函数用法实例分析
Jun 04 Python
详解Python解决抓取内容乱码问题(decode和encode解码)
Mar 29 Python
python里运用私有属性和方法总结
Jul 08 Python
用sqlalchemy构建Django连接池的实例
Aug 29 Python
python 两个数据库postgresql对比
Oct 21 Python
在python3中实现更新界面
Feb 21 Python
python源文件的字符编码知识点详解
Mar 04 Python
Python+uiautomator2实现自动刷抖音视频功能
Apr 29 Python
教你如何使用Python Tkinter库制作记事本
Jun 10 Python
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
解决Pytorch dataloader时报错每个tensor维度不一样的问题
May 28 #Python
pytorch锁死在dataloader(训练时卡死)
Python趣味爬虫之用Python实现智慧校园一键评教
Pytorch 如何加速Dataloader提升数据读取速度
You might like
PHP开发环境配置(MySQL数据库安装图文教程)
2010/04/28 PHP
求PHP数组最大值,最小值的代码
2011/10/31 PHP
php5.3中连接sqlserver2000的两种方法(com与ODBC)
2012/12/29 PHP
php线性表的入栈与出栈实例分析
2015/06/12 PHP
PHP实现基于mysqli的Model基类完整实例
2016/04/08 PHP
yum命令安装php7和相关扩展
2016/07/04 PHP
PHP自动补全表单的两种方法
2017/03/06 PHP
PHP读取CSV大文件导入数据库的实例
2017/07/24 PHP
php+ajax 文件上传代码实例
2019/03/18 PHP
小议Function.apply() 之一------(函数的劫持与对象的复制)
2006/11/30 Javascript
javascript textarea光标定位方法(兼容IE和FF)
2011/03/12 Javascript
jquery+css实现的红色线条横向二级菜单效果
2015/08/22 Javascript
Javascript之Math对象详解
2016/06/07 Javascript
jQuery Ajax请求后台数据并在前台接收
2016/12/10 Javascript
jQuery自定义插件详解及实例代码
2016/12/29 Javascript
微信扫码支付零云插件版实例详解
2017/04/26 Javascript
12条写出高质量JS代码的方法
2018/01/07 Javascript
详解Vue CLI3配置之filenameHashing使用和源码设计使用和源码设计
2018/08/31 Javascript
[01:53]DOTA2超级联赛专访Zhou 五年职业青春成长
2013/05/29 DOTA
python多线程http下载实现示例
2013/12/30 Python
python通过自定义isnumber函数判断字符串是否为数字的方法
2015/04/23 Python
Python中函数参数设置及使用的学习笔记
2016/05/03 Python
Python中遍历字典过程中更改元素导致异常的解决方法
2016/05/12 Python
利用python获取Ping结果示例代码
2017/07/06 Python
Python静态类型检查新工具之pyright 使用指南
2019/04/26 Python
Django用户认证系统 User对象解析
2019/08/02 Python
升级keras解决load_weights()中的未定义skip_mismatch关键字问题
2020/06/12 Python
解决pip install psycopg2出错问题
2020/07/09 Python
Django-silk性能测试工具安装及使用解析
2020/11/28 Python
.NET里面什么时候需要调用垃圾回收
2015/06/01 面试题
学习保证书怎么写
2015/02/26 职场文书
2015年财务部年度工作总结
2015/05/19 职场文书
Redis集群新增、删除节点以及动态增加内存的方法
2021/09/04 Redis
在项目中使用redis做缓存的一些思路
2021/09/14 Redis
JS封装cavans多种滤镜组件
2022/02/15 Javascript
html中两种获取标签内的值的方法
2022/06/16 jQuery