使用Pytorch实现two-head(多输出)模型的操作


Posted in Python onMay 28, 2021

如何使用Pytorch实现two-head(多输出)模型

1. two-head模型定义

先放一张我要实现的模型结构图:

使用Pytorch实现two-head(多输出)模型的操作

如上图,就是一个two-head模型,也是一个但输入多输出模型。该模型的特点是输入一个x和一个t,h0和h1中只有一个会输出,所以可能这不算是一个典型的多输出模型。

2.实现所遇到的困难 一开始的想法:

这不是很简单嘛,做一个判断不就完了,t=0时模型为前半段加h0,t=1时模型为前半段加h1。但实现的时候傻眼了,发现在真正前向传播的时候t是一个tensor,有0有1,没法儿进行判断。

灵机一动,又生一法:把这个模型变为三个模型,前半段是一个模型(r),后面的h0和h1分别为另两个模型。把数据集按t=0和1分开,分别训练两个模型:r+h0和r+h1。

但是后来搜如何进行模型串联,发现极为麻烦。

3.解决方案

后来在pytorch的官方社区中看到一个极为简单的方法:

(1) 按照一般的多输出模型进行实现,代码如下:

def forward(self, x):
        #三层的表示层
        x = F.elu(self.fcR1(x))
        x = F.elu(self.fcR2(x))
        x = F.elu(self.fcR3(x))
		#two-head,两个head分别进行输出
        y0 = F.elu(self.fcH01(x))
        y0 = F.elu(self.fcH02(y0))
        y0 = F.elu(self.fcH03(y0))
        y1 = F.elu(self.fcH11(x))
        y1 = F.elu(self.fcH12(y1))
        y1 = F.elu(self.fcH13(y1))
        return y0, y1

这样就相当实现了一个多输出模型,一个x同时输出y0和y1.

训练的时候分别训练,也即分别建loss,代码如下:

f_out_y0, _ = net(x0)
            _, f_out_y1 = net(x1)
            #实例化损失函数
            criterion0 = Loss()
            criterion1 = Loss()
            loss0 = criterion0(f_y0, f_out_y0, w0)
            loss1 = criterion1(f_y1, f_out_y1, w1)
            print(loss0.item(), loss1.item())
            #对网络参数进行初始化
            optimizer.zero_grad()
            loss0.backward()
            loss1.backward()
            #对网络的参数进行更新
            optimizer.step()

先把x按t=0和t=1分为x0和x1,然后分别送入进行训练。这样就实现了一个two-head模型。

4.后记

我自以为多输出模型可以分为以下两类:

多个输出不同时获得,如本文情况。

多个输出同时获得。

多输出不同时获得的解决方法上文已说明。多输出同时获得则可以通过把y0和y1拼接起来一起输出来实现。

补充:PyTorch 多输入多输出模型构建

本篇教程基于 PyTorch 1.5版本

直接上代码!

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.distributed as dist
import torch.utils.data as data_utils
class Net(nn.Module):
    def __init__(self, n_input, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden1 = nn.Linear(n_input, n_hidden)
        self.hidden2 = nn.Linear(n_hidden, n_hidden)
        self.predict1 = nn.Linear(n_hidden*2, n_output)
        self.predict2 = nn.Linear(n_hidden*2, n_output)
    def forward(self, input1, input2): # 多输入!!!
        out01 = self.hidden1(input1)
        out02 = torch.relu(out01)
        out03 = self.hidden2(out02)
        out04 = torch.sigmoid(out03)
        out11 = self.hidden1(input2)
        out12 = torch.relu(out11)
        out13 = self.hidden2(out12)
        out14 = torch.sigmoid(out13)
        out = torch.cat((out04, out14), dim=1) # 模型层拼合!!!当然你的模型中可能不需要~
 
        out1 = self.predict1(out)
        out2 = self.predict2(out)
        return out1, out2 # 多输出!!!
net = Net(1, 20, 1)
x1 = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # 请不要关心这里,随便弄一个数据,为了说明问题而已
y1 = x1.pow(3)+0.1*torch.randn(x1.size())
x2 = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y2 = x2.pow(3)+0.1*torch.randn(x2.size())
x1, y1 = (Variable(x1), Variable(y1))
x2, y2 = (Variable(x2), Variable(y2))
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
loss_func = torch.nn.MSELoss()
for t in range(5000):
    prediction1, prediction2 = net(x1, x2)
    loss1 = loss_func(prediction1, y1)
    loss2 = loss_func(prediction2, y2)
    loss = loss1 + loss2 # 重点!
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if t % 100 == 0:
       print('Loss1 = %.4f' % loss1.data,'Loss2 = %.4f' % loss2.data,)

至此搞定!

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现提取百度搜索结果的方法
May 19 Python
在Python程序中操作文件之flush()方法的使用教程
May 24 Python
Python画柱状统计图操作示例【基于matplotlib库】
Jul 04 Python
Flask框架实现给视图函数增加装饰器操作示例
Jul 16 Python
python实现邮件自动发送
Aug 10 Python
解决Django migrate不能发现app.models的表问题
Aug 31 Python
python将邻接矩阵输出成图的实现
Nov 21 Python
Python安装tar.gz格式文件方法详解
Jan 19 Python
python 日志模块 日志等级设置失效的解决方案
May 26 Python
Python+OpenCV检测灯光亮点的实现方法
Nov 02 Python
使用qt quick-ListView仿微信好友列表和聊天列表的示例代码
Jun 13 Python
python前后端自定义分页器
Apr 13 Python
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
解决Pytorch dataloader时报错每个tensor维度不一样的问题
May 28 #Python
pytorch锁死在dataloader(训练时卡死)
Python趣味爬虫之用Python实现智慧校园一键评教
Pytorch 如何加速Dataloader提升数据读取速度
You might like
地摊中国 - 珍藏老照片
2020/08/18 杂记
编写漂亮的代码 - 将后台程序与前端程序分开
2008/04/23 PHP
浅谈json_encode用法
2015/03/05 PHP
PHP下载生成的csv文件及问题总结
2015/08/06 PHP
通过javascript的匿名函数来分析几段简单有趣的代码
2010/06/29 Javascript
关于JavaScript中的关联数组分析
2013/04/09 Javascript
用原生JavaScript实现jQuery的$.getJSON的解决方法
2013/05/03 Javascript
可兼容IE的获取及设置cookie的jquery.cookie函数方法
2013/09/02 Javascript
js针对ip地址、子网掩码、网关的逻辑性判断
2016/01/06 Javascript
JS操作JSON方法总结(推荐)
2016/06/14 Javascript
angularjs 中$apply,$digest,$watch详解
2016/10/13 Javascript
JavaScript实现经纬度转换成地址功能
2017/03/28 Javascript
解决循环中setTimeout执行顺序的问题
2018/06/20 Javascript
JS使用new操作符创建对象的方法分析
2019/05/30 Javascript
使用layui实现的左侧菜单栏以及动态操作tab项方法
2019/09/10 Javascript
npm ci命令的基本使用方法
2020/09/20 Javascript
swiper4实现移动端导航栏tab滑动切换
2020/10/16 Javascript
[02:28]DOTA2 2015国际邀请赛中国区预选赛首日现场百态
2015/05/26 DOTA
[28:48]《真视界》- 2017年国际邀请赛
2017/09/27 DOTA
Python处理RSS、ATOM模块FEEDPARSER介绍
2015/02/18 Python
python执行get提交的方法
2015/04/29 Python
详解Golang 与python中的字符串反转
2017/07/21 Python
TF-IDF与余弦相似性的应用(二) 找出相似文章
2017/12/21 Python
python使用scrapy发送post请求的坑
2018/09/04 Python
Python实现合并两个有序链表的方法示例
2019/01/31 Python
python3实现简单飞机大战
2020/11/29 Python
基于Python 函数和方法的区别说明
2021/03/24 Python
高中的职业生涯规划书
2013/12/28 职场文书
医务工作者先进事迹材料
2014/01/26 职场文书
酒店总经理岗位职责
2014/03/17 职场文书
《1942》观后感
2015/06/08 职场文书
学校隐患排查制度
2015/08/05 职场文书
煤矿施工安全协议书
2016/03/22 职场文书
创业计划书之农家乐
2019/10/09 职场文书
健身房被搭讪?用python写了个小米计时器助人为乐
2021/06/08 Python
分享五个Node.js开发的优秀实践 
2022/04/07 NodeJs