使用Pytorch训练two-head网络的操作


Posted in Python onMay 28, 2021

以前的训练方法:

之前是把两个head分开进行训练的,因此每一轮训练先要对一个batch的数据进行划分,然后再分别训练两个头。代码如下:

f_out_y0, _ = net(x0)
            _, f_out_y1 = net(x1)
            #实例化损失函数
            criterion0 = Loss()
            criterion1 = Loss()
            loss0 = criterion0(f_y0, f_out_y0, w0)
            loss1 = criterion1(f_y1, f_out_y1, w1)
            print(loss0.item(), loss1.item())
            #对网络参数进行初始化
            optimizer.zero_grad()
            loss0.backward()
            loss1.backward()
            #对网络的参数进行更新
            optimizer.step()

但是在实际操作中想到那这样的话岂不是每次都先使用t=0的数据训练公共的表示层,再使用t=1的数据去训练。这样会不会使表示层产生bias呢?且这样两步训练也很麻烦。

修改后的方法

使用之前训练方法其实还是对神经网络的训练的机理不清楚。事实上,在计算loss的时候每个数据点的梯度都是单独计算的。

因此完全可以把网络前向传播得到结果按之前的顺序拼接起来后再进行梯度的反向传播,这样就可以只进行一步训练,且不会出现训练先后的偏差。

代码如下:

f_out_y0, cf_out_y0 = net(x0)
            cf_out_y1, f_out_y1 = net(x1)
            #按照t=0和t=1的索引拼接向量
            y_pred = torch.zeros([len(x), 1])
            y_pred[index0] = f_out_y0
            y_pred[index1] = f_out_y1
            
   criterion = Loss()
            loss = criterion(f_y, y_pred, w) + 0.01 * (l2_regularization0 + l2_regularization1)
            #print(loss.item())
            viz.line([float(loss)], [epoch], win='train_loss', update='append')
            optimizer.zero_grad()
            loss.backward()
            #对网络的参数进行更新
            optimizer.step()

总结

two-head网络前向传播得到结果的时候是分开得到的,训练的时候通过拼接预测结果可以实现一次训练。

补充:Pytorch训练网络的一般步骤

如下所示:

import torch 
print(torch.tensor([1,2,3],dtype=torch.float))#将一个列表强制转换为torch.Tensor类型
print(torch.randn(5,3))#生成torch.Tensor类型的5X3的随机数

1、构建模型

2、定义一个损失函数

3、定义一个优化器

4、将训练数据带入模型得到预测值

5、将梯度清零

6、获得损失

7、进行优化

import torch
from torch.autograd import Variable
 
#初步认识构建Tensor数据
def one():
    print(torch.tensor([1,2,3],dtype=torch.float))#将一个列表强制转换为torch.Tensor类型
    print(torch.randn(5,3))#生成torch.Tensor类型的5X3的随机数
    print(torch.zeros((2,3)))#生成一个2X3的全零矩阵
    print(torch.ones((2,3)))#生成一个2X3的全一矩阵
    a = torch.randn((2,3))
    b = a.numpy()#将一个torch.Tensor转换为numpy
    c = torch.from_numpy(b)#将numpy转换为Tensor
    print(a)
    print(b)
    print(c)
 
#使用Variable自动求导
def two():
    # 构建Variable
    x = Variable(torch.Tensor([1, 2, 3]), requires_grad=True)
    w = Variable(torch.Tensor([4, 5, 6]), requires_grad=True)
    b = Variable(torch.Tensor([7, 8, 9]), requires_grad=True)
    # 函数等式
    y = w * x ** 2 + b
    # 使用梯度下降计算各变量的偏导数
    y.backward(torch.Tensor([1, 1, 1]))
    print(x.grad)
    print(w.grad)
    print(b.grad)

线性回归例子:

import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
 
x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y = 3*x+10+torch.rand(x.size())
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression,self).__init__()
        self.Linear = nn.Linear(1,1)
    def forward(self,x):
        return self.Linear(x)
model = LinearRegression()
Loss = nn.MSELoss()
Opt = torch.optim.SGD(model.parameters(),lr=0.01)
for i in range(1000):
    inputs = Variable(x)
    targets = Variable(y)
    outputs = model(inputs)
    loss = Loss(outputs,targets)
    Opt.zero_grad()
    loss.backward()
    Opt.step()
model.eval()
predict = model(Variable(x))
plt.plot(x.numpy(),y.numpy(),'ro')
plt.plot(x.numpy(),predict.data.numpy())
plt.show()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python paramiko实现ssh远程访问的方法
Dec 03 Python
Python3实现生成随机密码的方法
Aug 23 Python
零基础写python爬虫之爬虫的定义及URL构成
Nov 04 Python
在Linux系统上部署Apache+Python+Django+MySQL环境
Dec 24 Python
Python二叉树定义与遍历方法实例分析
May 25 Python
python 将字符串中的数字相加求和的实现
Jul 18 Python
pycharm 批量修改变量名称的方法
Aug 01 Python
pandas中的数据去重处理的实现方法
Feb 10 Python
Python数据可视化处理库PyEcharts柱状图,饼图,线性图,词云图常用实例详解
Feb 10 Python
使用python爬取抖音app视频的实例代码
Dec 01 Python
python 基于pygame实现俄罗斯方块
Mar 02 Python
python实现简单倒计时功能
Apr 21 Python
使用Python的开发框架Brownie部署以太坊智能合约
使用Pytorch实现two-head(多输出)模型的操作
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
解决Pytorch dataloader时报错每个tensor维度不一样的问题
May 28 #Python
pytorch锁死在dataloader(训练时卡死)
You might like
神盾加密解密教程(二)PHP 神盾解密
2014/06/08 PHP
PHP输出一个等腰三角形的方法
2015/05/12 PHP
PHP IDE PHPStorm配置支持友好Laravel代码提示方法
2015/05/12 PHP
PHP版单点登陆实现方案的实例
2016/11/17 PHP
php输出图像的方法实例分析
2017/02/16 PHP
基于JQuery实现相同内容合并单元格的代码
2011/01/12 Javascript
图片延迟加载的实现代码(模仿懒惰)
2013/03/29 Javascript
jQuery对象的selector属性用法实例
2014/12/27 Javascript
TinyMCE提交AjaxForm获取不到数据的解决方法
2015/03/05 Javascript
基于JavaScript代码实现微信扫一扫下载APP
2015/12/30 Javascript
将Sublime Text 3 添加到右键中的简单方法
2017/12/12 Javascript
vue项目实现记住密码到cookie功能示例(附源码)
2018/01/31 Javascript
nodejs爬虫初试superagent和cheerio
2018/03/05 NodeJs
15分钟深入了解JS继承分类、原理与用法
2019/01/19 Javascript
[44:39]2014 DOTA2国际邀请赛中国区预选赛 NE VS CNB
2014/05/21 DOTA
[01:11]回顾历届DOTA2国际邀请赛中国区预选赛
2017/06/26 DOTA
Python with的用法
2014/08/22 Python
Python while、for、生成器、列表推导等语句的执行效率测试
2015/06/03 Python
破解安装Pycharm的方法
2018/10/19 Python
如何通过python的fabric包完成代码上传部署
2019/07/29 Python
Django认证系统实现的web页面实现代码
2019/08/12 Python
python防止随意修改类属性的实现方法
2019/08/21 Python
使用Python paramiko模块利用多线程实现ssh并发执行操作
2019/12/05 Python
python3.5的包存放的具体路径
2020/08/16 Python
Python爬虫如何破解JS加密的Cookie
2020/11/19 Python
使用CSS3 制作一个material-design 风格登录界面实例
2016/12/12 HTML / CSS
RIP版本1跟版本2的区别
2013/12/30 面试题
与C++相比,Java中的数组有什么不同
2014/03/25 面试题
财务专业大学生职业生涯规划范文
2013/12/30 职场文书
2014年检验科工作总结
2014/11/22 职场文书
2015年专项整治工作总结
2015/04/03 职场文书
2015年导购员工作总结
2015/04/25 职场文书
2015年大学学生会工作总结
2015/05/13 职场文书
用Python可视化新冠疫情数据
2022/01/18 Python
详细介绍python操作RabbitMq
2022/04/12 Python
python中 .npy文件的读写操作实例
2022/04/14 Python