pytorch 在网络中添加可训练参数,修改预训练权重文件的方法


Posted in Python onAugust 17, 2019

实践中,针对不同的任务需求,我们经常会在现成的网络结构上做一定的修改来实现特定的目的。

假如我们现在有一个简单的两层感知机网络:

# -*- coding: utf-8 -*-
import torch
from torch.autograd import Variable
import torch.optim as optim
 
x = Variable(torch.FloatTensor([1, 2, 3])).cuda()
y = Variable(torch.FloatTensor([4, 5])).cuda()
 
class MLP(torch.nn.Module):
  def __init__(self):
    super(MLP, self).__init__()
    self.linear1 = torch.nn.Linear(3, 5)
    self.relu = torch.nn.ReLU()
    self.linear2 = torch.nn.Linear(5, 2)
 
  def forward(self, x):
    x = self.linear1(x)
    x = self.relu(x)
    x = self.linear2(x)
 
    return x
 
model = MLP().cuda()
 
loss_fn = torch.nn.MSELoss(size_average=False)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
 
for t in range(500):
  y_pred = model(x)
  loss = loss_fn(y_pred, y)
  print(t, loss.data[0])
  model.zero_grad()
  loss.backward()
  optimizer.step()
 
print(model(x))

现在想在前向传播时,在relu之后给x乘以一个可训练的系数,只需要在__init__函数中添加一个nn.Parameter类型变量,并在forward函数中乘以该变量即可:

class MLP(torch.nn.Module):
  def __init__(self):
    super(MLP, self).__init__()
    self.linear1 = torch.nn.Linear(3, 5)
    self.relu = torch.nn.ReLU()
    self.linear2 = torch.nn.Linear(5, 2)
    # the para to be added and updated in train phase, note that NO cuda() at last
    self.coefficient = torch.nn.Parameter(torch.Tensor([1.55]))
 
  def forward(self, x):
    x = self.linear1(x)
    x = self.relu(x)
    x = self.coefficient * x
    x = self.linear2(x)
 
    return x

注意,Parameter变量和Variable变量的操作大致相同,但是不能手动调用.cuda()方法将其加载在GPU上,事实上它会自动在GPU上加载,可以通过model.state_dict()或者model.named_parameters()函数查看现在的全部可训练参数(包括通过继承得到的父类中的参数):

print(model.state_dict().keys())
for i, j in model.named_parameters():
  print(i)
  print(j)

输出如下:

odict_keys(['linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias'])
linear1.weight
Parameter containing:
-0.3582 -0.0283 0.2607
 0.5190 -0.2221 0.0665
-0.2586 -0.3311 0.1927
-0.2765 0.5590 -0.2598
 0.4679 -0.2923 -0.3379
[torch.cuda.FloatTensor of size 5x3 (GPU 0)]
 
linear1.bias
Parameter containing:
-0.2549
-0.5246
-0.1109
 0.5237
-0.1362
[torch.cuda.FloatTensor of size 5 (GPU 0)]
 
linear2.weight
Parameter containing:
-0.0286 -0.3045 0.1928 -0.2323 0.2966
 0.2601 0.1441 -0.2159 0.2484 0.0544
[torch.cuda.FloatTensor of size 2x5 (GPU 0)]
 
linear2.bias
Parameter containing:
-0.4038
 0.3129
[torch.cuda.FloatTensor of size 2 (GPU 0)]

这个参数会在反向传播时与原有变量同时参与更新,这就达到了添加可训练参数的目的。

如果我们有原先网络的预训练权重,现在添加了一个新的参数,原有的权重文件自然就不能加载了,我们需要修改原权重文件,在其中添加我们的新变量的初始值。

调用model.state_dict查看我们添加的参数在参数字典中的完整名称,然后打开原先的权重文件:

a = torch.load("OldWeights.pth") a是一个collecitons.OrderedDict类型变量,也就是一个有序字典,直接将新参数名称和初始值作为键值对插入,然后保存即可。

a = torch.load("OldWeights.pth")
 
a["layer1.0.coefficient"] = torch.FloatTensor([1.2])
a["layer1.1.coefficient"] = torch.FloatTensor([1.5])
 
torch.save(a, "Weights.pth")

现在权重就可以加载在修改后的模型上了。

以上这篇pytorch 在网络中添加可训练参数,修改预训练权重文件的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现的各种排序算法代码
Mar 04 Python
Python pass详细介绍及实例代码
Nov 24 Python
利用python生成一个导出数据库的bat脚本文件的方法
Dec 30 Python
用python实现k近邻算法的示例代码
Sep 06 Python
Django2.1集成xadmin管理后台所遇到的错误集锦(填坑)
Dec 20 Python
Numpy之random函数使用学习
Jan 29 Python
Django 重写用户模型的实现
Jul 29 Python
Django自带的加密算法及加密模块详解
Dec 03 Python
python 截取XML中bndbox的坐标中的图像,另存为jpg的实例
Mar 10 Python
Python读入mnist二进制图像文件并显示实例
Apr 24 Python
简述python&pytorch 随机种子的实现
Oct 07 Python
用python开发一款操作MySQL的小工具
May 12 Python
python PyQt5/Pyside2 按钮右击菜单实例代码
Aug 17 #Python
Pytorch 实现自定义参数层的例子
Aug 17 #Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 #Python
画pytorch模型图,以及参数计算的方法
Aug 17 #Python
pytorch 共享参数的示例
Aug 17 #Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
You might like
Laravel等框架模型关联的可用性浅析
2019/12/15 PHP
判断对象是否Window的实现代码
2012/01/10 Javascript
利用jquery的获取JS文件中的字符串内容
2012/02/14 Javascript
JavaScript实现的GBK、UTF8字符串实际长度计算函数
2014/08/27 Javascript
jQuery中借助deferred来请求及判断AJAX加载的实例讲解
2016/05/24 Javascript
jQuery实现返回顶部按钮和scroll滚动功能[带动画效果]
2017/07/05 jQuery
详解angular脏检查原理及伪代码实现
2018/06/08 Javascript
vue.js中toast用法及使用toast弹框的实例代码
2018/08/27 Javascript
浅谈VUE单页应用首屏加载速度优化方案
2018/08/28 Javascript
深入理解NodeJS 多进程和集群
2018/10/17 NodeJs
微信小程序地图(map)组件点击(tap)获取经纬度的方法
2019/01/10 Javascript
前端面试知识点目录一览
2019/04/15 Javascript
JS实现的定时器展示简单秒表、页面弹框及跳转操作完整示例
2020/01/26 Javascript
python创建和删除目录的方法
2015/04/29 Python
Python自动扫雷实现方法
2015/07/25 Python
解决pycharm运行程序出现卡住scanning files to index索引的问题
2019/06/27 Python
关于Python中的向量相加和numpy中的向量相加效率对比
2019/08/26 Python
python如何进入交互模式
2020/07/06 Python
Flask中sqlalchemy模块的实例用法
2020/08/02 Python
python接口自动化之ConfigParser配置文件的使用详解
2020/08/03 Python
如何解决python多种版本冲突问题
2020/10/13 Python
python 用Matplotlib作图中有多个Y轴
2020/11/28 Python
CSS3制作日历实现代码
2012/01/21 HTML / CSS
前后端结合实现amazeUI分页效果
2020/08/21 HTML / CSS
Kipling凯浦林美国官网:世界著名时尚休闲包袋品牌
2016/08/24 全球购物
美国设计师精美珠宝购物网:Netaya
2016/08/28 全球购物
Snapfish英国:在线照片打印和个性化照片礼品
2017/01/13 全球购物
菲律宾酒店预订网站:Hotels.com菲律宾
2017/07/12 全球购物
美国新兴城市生活方式零售商:VILLA
2017/12/06 全球购物
瑞典网上购买现代和复古家具:Reforma
2019/10/21 全球购物
什么是ARP(Address Resolution Protocol)地址解析协议
2013/10/31 面试题
第一批党的群众路线教育实践活动工作总结
2014/03/03 职场文书
幼儿园门卫岗位职责范本
2014/07/02 职场文书
工作表扬信范文
2015/01/17 职场文书
vue如何批量引入组件、注册和使用详解
2021/05/12 Vue.js
Java 死锁解决方案
2022/05/11 Java/Android