使用Pytorch训练two-head网络的操作


Posted in Python onMay 28, 2021

以前的训练方法:

之前是把两个head分开进行训练的,因此每一轮训练先要对一个batch的数据进行划分,然后再分别训练两个头。代码如下:

f_out_y0, _ = net(x0)
            _, f_out_y1 = net(x1)
            #实例化损失函数
            criterion0 = Loss()
            criterion1 = Loss()
            loss0 = criterion0(f_y0, f_out_y0, w0)
            loss1 = criterion1(f_y1, f_out_y1, w1)
            print(loss0.item(), loss1.item())
            #对网络参数进行初始化
            optimizer.zero_grad()
            loss0.backward()
            loss1.backward()
            #对网络的参数进行更新
            optimizer.step()

但是在实际操作中想到那这样的话岂不是每次都先使用t=0的数据训练公共的表示层,再使用t=1的数据去训练。这样会不会使表示层产生bias呢?且这样两步训练也很麻烦。

修改后的方法

使用之前训练方法其实还是对神经网络的训练的机理不清楚。事实上,在计算loss的时候每个数据点的梯度都是单独计算的。

因此完全可以把网络前向传播得到结果按之前的顺序拼接起来后再进行梯度的反向传播,这样就可以只进行一步训练,且不会出现训练先后的偏差。

代码如下:

f_out_y0, cf_out_y0 = net(x0)
            cf_out_y1, f_out_y1 = net(x1)
            #按照t=0和t=1的索引拼接向量
            y_pred = torch.zeros([len(x), 1])
            y_pred[index0] = f_out_y0
            y_pred[index1] = f_out_y1
            
   criterion = Loss()
            loss = criterion(f_y, y_pred, w) + 0.01 * (l2_regularization0 + l2_regularization1)
            #print(loss.item())
            viz.line([float(loss)], [epoch], win='train_loss', update='append')
            optimizer.zero_grad()
            loss.backward()
            #对网络的参数进行更新
            optimizer.step()

总结

two-head网络前向传播得到结果的时候是分开得到的,训练的时候通过拼接预测结果可以实现一次训练。

补充:Pytorch训练网络的一般步骤

如下所示:

import torch 
print(torch.tensor([1,2,3],dtype=torch.float))#将一个列表强制转换为torch.Tensor类型
print(torch.randn(5,3))#生成torch.Tensor类型的5X3的随机数

1、构建模型

2、定义一个损失函数

3、定义一个优化器

4、将训练数据带入模型得到预测值

5、将梯度清零

6、获得损失

7、进行优化

import torch
from torch.autograd import Variable
 
#初步认识构建Tensor数据
def one():
    print(torch.tensor([1,2,3],dtype=torch.float))#将一个列表强制转换为torch.Tensor类型
    print(torch.randn(5,3))#生成torch.Tensor类型的5X3的随机数
    print(torch.zeros((2,3)))#生成一个2X3的全零矩阵
    print(torch.ones((2,3)))#生成一个2X3的全一矩阵
    a = torch.randn((2,3))
    b = a.numpy()#将一个torch.Tensor转换为numpy
    c = torch.from_numpy(b)#将numpy转换为Tensor
    print(a)
    print(b)
    print(c)
 
#使用Variable自动求导
def two():
    # 构建Variable
    x = Variable(torch.Tensor([1, 2, 3]), requires_grad=True)
    w = Variable(torch.Tensor([4, 5, 6]), requires_grad=True)
    b = Variable(torch.Tensor([7, 8, 9]), requires_grad=True)
    # 函数等式
    y = w * x ** 2 + b
    # 使用梯度下降计算各变量的偏导数
    y.backward(torch.Tensor([1, 1, 1]))
    print(x.grad)
    print(w.grad)
    print(b.grad)

线性回归例子:

import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
 
x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y = 3*x+10+torch.rand(x.size())
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression,self).__init__()
        self.Linear = nn.Linear(1,1)
    def forward(self,x):
        return self.Linear(x)
model = LinearRegression()
Loss = nn.MSELoss()
Opt = torch.optim.SGD(model.parameters(),lr=0.01)
for i in range(1000):
    inputs = Variable(x)
    targets = Variable(y)
    outputs = model(inputs)
    loss = Loss(outputs,targets)
    Opt.zero_grad()
    loss.backward()
    Opt.step()
model.eval()
predict = model(Variable(x))
plt.plot(x.numpy(),y.numpy(),'ro')
plt.plot(x.numpy(),predict.data.numpy())
plt.show()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python修改注册表终止360进程实例
Oct 13 Python
python使用socket远程连接错误处理方法
Apr 29 Python
python僵尸进程产生的原因
Jul 21 Python
Python迭代器和生成器定义与用法示例
Feb 10 Python
python实现音乐下载器
Apr 15 Python
Python unittest模块用法实例分析
May 25 Python
python中实现字符串翻转的方法
Jul 11 Python
Python实现正整数分解质因数操作示例
Aug 01 Python
详解Ubuntu16.04安装Python3.7及其pip3并切换为默认版本
Feb 25 Python
python scipy卷积运算的实现方法
Sep 16 Python
用Python爬虫破解滑动验证码的案例解析
May 06 Python
Django操作cookie的实现
May 26 Python
使用Python的开发框架Brownie部署以太坊智能合约
使用Pytorch实现two-head(多输出)模型的操作
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
解决Pytorch dataloader时报错每个tensor维度不一样的问题
May 28 #Python
pytorch锁死在dataloader(训练时卡死)
You might like
PHP系统流量分析的程序
2006/10/09 PHP
PHP文件大小格式化函数合集
2014/03/10 PHP
ThinkPHP CURD方法之page方法详解
2014/06/18 PHP
教你如何解密 “ PHP 神盾解密工具 ”
2014/06/20 PHP
举例详解PHP脚本的测试方法
2015/08/05 PHP
JS 继承实例分析
2008/11/04 Javascript
Javascript学习笔记二 之 变量
2010/12/15 Javascript
javascript基础知识大集锦(二) 推荐收藏
2011/01/13 Javascript
Node.js:Windows7下搭建的Node.js服务(来玩玩服务器端的javascript吧,这可不是前端js插件)
2011/06/27 Javascript
jQuery实现带滚动导航效果的全屏滚动相册实例
2015/06/19 Javascript
jQuery 实现评论等级好评差评特效
2016/05/06 Javascript
Vue.js报错Failed to resolve filter问题的解决方法
2016/05/25 Javascript
微信小程序 生命周期详解
2016/10/12 Javascript
AngularJS控制器controller给模型数据赋初始值的方法
2017/01/04 Javascript
用jquery的attr方法实现图片切换效果
2017/02/05 Javascript
Angular 4依赖注入学习教程之FactoryProvider配置依赖对象(五)
2017/06/04 Javascript
使用Vue如何写一个双向数据绑定(面试常见)
2018/04/20 Javascript
javascript将16进制的字符串转换为10进制整数hex
2020/03/05 Javascript
JavaScript如何使用插值实现图像渐变
2020/06/28 Javascript
node koa2 ssr项目搭建的方法步骤
2020/12/11 Javascript
[48:00]完美世界DOTA2联赛循环赛 Forest vs Inki BO2第二场 11.04
2020/11/04 DOTA
python使用datetime模块计算各种时间间隔的方法
2015/03/24 Python
selenium 多窗口切换的实现(windows)
2020/01/18 Python
pycharm部署、配置anaconda环境的教程
2020/03/24 Python
Python环境下安装PyGame和PyOpenGL的方法
2020/03/25 Python
浅谈django channels 路由误导
2020/05/28 Python
一款超酷的js+css3实现的3D标签云特效兼容ie7/8/9
2013/11/18 HTML / CSS
南非领先的在线旅行社:Travelstart南非
2016/09/04 全球购物
求最大连续递增数字串(如"ads3sl456789DF3456ld345AA"中的"456789")
2015/09/11 面试题
大学生求职自荐信
2013/12/12 职场文书
大学军训通讯稿
2014/01/13 职场文书
2014年银行员工工作总结
2014/11/12 职场文书
2014年消防工作总结
2014/11/21 职场文书
职业生涯规划书之大学四年
2019/08/07 职场文书
详解Python生成器和基于生成器的协程
2021/06/03 Python
HTML中的表格元素介绍
2022/02/28 HTML / CSS