使用Pytorch训练two-head网络的操作


Posted in Python onMay 28, 2021

以前的训练方法:

之前是把两个head分开进行训练的,因此每一轮训练先要对一个batch的数据进行划分,然后再分别训练两个头。代码如下:

f_out_y0, _ = net(x0)
            _, f_out_y1 = net(x1)
            #实例化损失函数
            criterion0 = Loss()
            criterion1 = Loss()
            loss0 = criterion0(f_y0, f_out_y0, w0)
            loss1 = criterion1(f_y1, f_out_y1, w1)
            print(loss0.item(), loss1.item())
            #对网络参数进行初始化
            optimizer.zero_grad()
            loss0.backward()
            loss1.backward()
            #对网络的参数进行更新
            optimizer.step()

但是在实际操作中想到那这样的话岂不是每次都先使用t=0的数据训练公共的表示层,再使用t=1的数据去训练。这样会不会使表示层产生bias呢?且这样两步训练也很麻烦。

修改后的方法

使用之前训练方法其实还是对神经网络的训练的机理不清楚。事实上,在计算loss的时候每个数据点的梯度都是单独计算的。

因此完全可以把网络前向传播得到结果按之前的顺序拼接起来后再进行梯度的反向传播,这样就可以只进行一步训练,且不会出现训练先后的偏差。

代码如下:

f_out_y0, cf_out_y0 = net(x0)
            cf_out_y1, f_out_y1 = net(x1)
            #按照t=0和t=1的索引拼接向量
            y_pred = torch.zeros([len(x), 1])
            y_pred[index0] = f_out_y0
            y_pred[index1] = f_out_y1
            
   criterion = Loss()
            loss = criterion(f_y, y_pred, w) + 0.01 * (l2_regularization0 + l2_regularization1)
            #print(loss.item())
            viz.line([float(loss)], [epoch], win='train_loss', update='append')
            optimizer.zero_grad()
            loss.backward()
            #对网络的参数进行更新
            optimizer.step()

总结

two-head网络前向传播得到结果的时候是分开得到的,训练的时候通过拼接预测结果可以实现一次训练。

补充:Pytorch训练网络的一般步骤

如下所示:

import torch 
print(torch.tensor([1,2,3],dtype=torch.float))#将一个列表强制转换为torch.Tensor类型
print(torch.randn(5,3))#生成torch.Tensor类型的5X3的随机数

1、构建模型

2、定义一个损失函数

3、定义一个优化器

4、将训练数据带入模型得到预测值

5、将梯度清零

6、获得损失

7、进行优化

import torch
from torch.autograd import Variable
 
#初步认识构建Tensor数据
def one():
    print(torch.tensor([1,2,3],dtype=torch.float))#将一个列表强制转换为torch.Tensor类型
    print(torch.randn(5,3))#生成torch.Tensor类型的5X3的随机数
    print(torch.zeros((2,3)))#生成一个2X3的全零矩阵
    print(torch.ones((2,3)))#生成一个2X3的全一矩阵
    a = torch.randn((2,3))
    b = a.numpy()#将一个torch.Tensor转换为numpy
    c = torch.from_numpy(b)#将numpy转换为Tensor
    print(a)
    print(b)
    print(c)
 
#使用Variable自动求导
def two():
    # 构建Variable
    x = Variable(torch.Tensor([1, 2, 3]), requires_grad=True)
    w = Variable(torch.Tensor([4, 5, 6]), requires_grad=True)
    b = Variable(torch.Tensor([7, 8, 9]), requires_grad=True)
    # 函数等式
    y = w * x ** 2 + b
    # 使用梯度下降计算各变量的偏导数
    y.backward(torch.Tensor([1, 1, 1]))
    print(x.grad)
    print(w.grad)
    print(b.grad)

线性回归例子:

import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
 
x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y = 3*x+10+torch.rand(x.size())
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression,self).__init__()
        self.Linear = nn.Linear(1,1)
    def forward(self,x):
        return self.Linear(x)
model = LinearRegression()
Loss = nn.MSELoss()
Opt = torch.optim.SGD(model.parameters(),lr=0.01)
for i in range(1000):
    inputs = Variable(x)
    targets = Variable(y)
    outputs = model(inputs)
    loss = Loss(outputs,targets)
    Opt.zero_grad()
    loss.backward()
    Opt.step()
model.eval()
predict = model(Variable(x))
plt.plot(x.numpy(),y.numpy(),'ro')
plt.plot(x.numpy(),predict.data.numpy())
plt.show()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
粗略分析Python中的内存泄漏
Apr 23 Python
详解python里使用正则表达式的分组命名方式
Oct 24 Python
用Python写王者荣耀刷金币脚本
Dec 21 Python
python实现学生管理系统
Jan 11 Python
python爬取微信公众号文章
Aug 31 Python
python得到windows自启动列表的方法
Oct 14 Python
python dlib人脸识别代码实例
Apr 04 Python
python django model联合主键的例子
Aug 06 Python
python手机号前7位归属地爬虫代码实例
Mar 31 Python
python利用proxybroker构建爬虫免费IP代理池的实现
Feb 21 Python
python批量创建变量并赋值操作
Jun 03 Python
Python 键盘事件详解
Nov 11 Python
使用Python的开发框架Brownie部署以太坊智能合约
使用Pytorch实现two-head(多输出)模型的操作
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
解决Pytorch dataloader时报错每个tensor维度不一样的问题
May 28 #Python
pytorch锁死在dataloader(训练时卡死)
You might like
劣质的PHP代码简化
2010/02/08 PHP
并发下常见的加锁及锁的PHP具体实现代码
2010/10/12 PHP
PHP实现把文本中的URL转换为链接的auolink()函数分享
2014/07/29 PHP
详解PHP的Yii框架的运行机制及其路由功能
2016/03/17 PHP
php上传图片类及用法示例
2016/05/11 PHP
PHP连接MYSQL数据库的3种常用方法
2017/02/27 PHP
线路分流自动跳转代码;希望对大家有用!
2006/12/02 Javascript
解javascript 混淆加密收藏
2009/01/16 Javascript
使用 Node.js 做 Function Test实现方法
2013/10/25 Javascript
点击页面其它地方隐藏该div的两种思路
2013/11/18 Javascript
再谈Jquery Ajax方法传递到action(补充)
2014/05/12 Javascript
JS+Canvas 实现下雨下雪效果
2016/05/18 Javascript
jQuery复制节点用法示例(clone方法)
2016/09/08 Javascript
jquery 实现回车登录详解及实例代码
2016/10/23 Javascript
基于JavaScript实现新增内容滚动播放效果附完整代码
2017/08/24 Javascript
es6学习之解构时应该注意的点
2017/08/29 Javascript
vue2 全局变量的设置方法
2018/03/09 Javascript
Vue ElementUi同时校验多个表单(巧用new promise)
2018/06/06 Javascript
JavaScript设计模式之建造者模式实例教程
2018/07/02 Javascript
vue.js指令v-for使用以及下标索引的获取
2019/01/31 Javascript
你不知道的Vue技巧之--开发一个可以通过方法调用的组件(推荐)
2019/04/15 Javascript
jquery绑定事件 bind和on的用法与区别分析
2020/05/22 jQuery
用Python实现一个简单的多线程TCP服务器的教程
2015/05/05 Python
python 获取网页编码方式实现代码
2017/03/11 Python
python实现快速排序的示例(二分法思想)
2018/03/12 Python
浅谈python脚本设置运行参数的方法
2018/12/03 Python
Python内置random模块生成随机数的方法
2019/05/31 Python
Django使用list对单个或者多个字段求values值实例
2020/03/31 Python
5分钟快速掌握Python定时任务框架的实现
2021/01/26 Python
德国网上花店:Valentins
2018/08/15 全球购物
SQL中where和having的区别
2012/06/17 面试题
工作中个人的自我评价
2013/12/31 职场文书
揭牌仪式策划方案
2014/05/28 职场文书
2016年10月份红领巾广播稿
2015/12/21 职场文书
MySQL系列之三 基础篇
2021/07/02 MySQL
铁头也玩根德 YachtBoy YB-230......
2022/04/05 无线电