pytorch中的weight-initilzation用法


Posted in Python onJune 24, 2020

pytorch中的权值初始化

官方论坛对weight-initilzation的讨论

torch.nn.Module.apply(fn)

torch.nn.Module.apply(fn)
# 递归的调用weights_init函数,遍历nn.Module的submodule作为参数
# 常用来对模型的参数进行初始化
# fn是对参数进行初始化的函数的句柄,fn以nn.Module或者自己定义的nn.Module的子类作为参数
# fn (Module -> None) ? function to be applied to each submodule
# Returns: self
# Return type: Module

例子:

def weights_init(m):
 classname = m.__class__.__name__
 if classname.find('Conv') != -1:
  m.weight.data.normal_(0.0, 0.02) 
  # m.weight.data是卷积核参数, m.bias.data是偏置项参数
 elif classname.find('BatchNorm') != -1:
  m.weight.data.normal_(1.0, 0.02)
  m.bias.data.fill_(0)

netG = _netG(ngpu) # 生成模型实例
netG.apply(weights_init) # 递归的调用weights_init函数,遍历netG的submodule作为参数
#-*-coding:utf-8-*-
import torch
from torch.autograd import Variable

# 对模型参数进行初始化
# 官方论坛链接:https://discuss.pytorch.org/t/weight-initilzation/157/3

# 方法一
# 单独定义一个weights_init函数,输入参数是m(torch.nn.module或者自己定义的继承nn.module的子类)
# 然后使用net.apply()进行参数初始化
# m.__class__.__name__ 获得nn.module的名字
# https://github.com/pytorch/examples/blob/master/dcgan/main.py#L90-L96
def weights_init(m):
 classname = m.__class__.__name__
 if classname.find('Conv') != -1:
  m.weight.data.normal_(0.0, 0.02)
 elif classname.find('BatchNorm') != -1:
  m.weight.data.normal_(1.0, 0.02)
  m.bias.data.fill_(0)

netG = _netG(ngpu) # 生成模型实例
netG.apply(weights_init) # 递归的调用weights_init函数,遍历netG的submodule作为参数

# function to be applied to each submodule

# 方法二
# 1. 使用net.modules()遍历模型中的网络层的类型 2. 对其中的m层的weigth.data(tensor)部分进行初始化操作
# Another initialization example from PyTorch Vision resnet implementation.
# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L112-L118
class ResNet(nn.Module):
 def __init__(self, block, layers, num_classes=1000):
  self.inplanes = 64
  super(ResNet, self).__init__()
  self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
        bias=False)
  self.bn1 = nn.BatchNorm2d(64)
  self.relu = nn.ReLU(inplace=True)
  self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  self.layer1 = self._make_layer(block, 64, layers[0])
  self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
  self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
  self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
  self.avgpool = nn.AvgPool2d(7, stride=1)
  self.fc = nn.Linear(512 * block.expansion, num_classes)
  # 权值参数初始化
  for m in self.modules():
   if isinstance(m, nn.Conv2d):
    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
    m.weight.data.normal_(0, math.sqrt(2. / n))
   elif isinstance(m, nn.BatchNorm2d):
    m.weight.data.fill_(1)
    m.bias.data.zero_()

# 方法三
# 自己知道网络中参数的顺序和类型, 然后将参数依次读取出来,调用torch.nn.init中的方法进行初始化
net = AlexNet(2)
params = list(net.parameters()) # params依次为Conv2d参数和Bias参数
# 或者
conv1Params = list(net.conv1.parameters())
# 其中,conv1Params[0]表示卷积核参数, conv1Params[1]表示bias项参数
# 然后使用torch.nn.init中函数进行初始化
torch.nn.init.normal(tensor, mean=0, std=1)
torch.nn.init.constant(tensor, 0)

# net.modules()迭代的返回: AlexNet,Sequential,Conv2d,ReLU,MaxPool2d,LRN,AvgPool3d....,Conv2d,...,Conv2d,...,Linear,
# 这里,只有Conv2d和Linear才有参数
# net.children()只返回实际存在的子模块: Sequential,Sequential,Sequential,Sequential,Sequential,Sequential,Sequential,Linear

# 附AlexNet的定义
class AlexNet(nn.Module):
 def __init__(self, num_classes = 2): # 默认为两类,猫和狗
#   super().__init__() # python3
  super(AlexNet, self).__init__()
  # 开始构建AlexNet网络模型,5层卷积,3层全连接层
  # 5层卷积层
  self.conv1 = nn.Sequential(
   nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4),
   nn.ReLU(inplace=True),
   nn.MaxPool2d(kernel_size=3, stride=2),
   LRN(local_size=5, bias=1, alpha=1e-4, beta=0.75, ACROSS_CHANNELS=True)
  )
  self.conv2 = nn.Sequential(
   nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, groups=2, padding=2),
   nn.ReLU(inplace=True),
   nn.MaxPool2d(kernel_size=3, stride=2),
   LRN(local_size=5, bias=1, alpha=1e-4, beta=0.75, ACROSS_CHANNELS=True)
  )
  self.conv3 = nn.Sequential(
   nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1),
   nn.ReLU(inplace=True)
  )
  self.conv4 = nn.Sequential(
   nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1),
   nn.ReLU(inplace=True)
  )
  self.conv5 = nn.Sequential(
   nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1),
   nn.ReLU(inplace=True),
   nn.MaxPool2d(kernel_size=3, stride=2)
  )
  # 3层全连接层
  # 前向计算的时候,最开始输入需要进行view操作,将3D的tensor变为1D
  self.fc6 = nn.Sequential(
   nn.Linear(in_features=6*6*256, out_features=4096),
   nn.ReLU(inplace=True),
   nn.Dropout()
  )
  self.fc7 = nn.Sequential(
   nn.Linear(in_features=4096, out_features=4096),
   nn.ReLU(inplace=True),
   nn.Dropout()
  )
  self.fc8 = nn.Linear(in_features=4096, out_features=num_classes)

 def forward(self, x):
  x = self.conv5(self.conv4(self.conv3(self.conv2(self.conv1(x)))))
  x = x.view(-1, 6*6*256)
  x = self.fc8(self.fc7(self.fc6(x)))
  return x

补充知识:pytorch Load部分weights

我们从网上down下来的模型与我们的模型可能就存在一个层的差异,此时我们就需要重新训练所有的参数是不合理的。

因此我们可以加载相同的参数,而忽略不同的参数,代码如下:

pretrained_dict = torch.load(“model.pth”)
  model_dict = et.state_dict()
  pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
  model_dict.update(pretrained_dict)
  net.load_state_dict(model_dict)

以上这篇pytorch中的weight-initilzation用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中将字典转换为XML以及相关的命名空间解析
Oct 15 Python
Python聊天室实例程序分享
Jan 05 Python
Python实现检测文件MD5值的方法示例
Apr 11 Python
解决Python中pandas读取*.csv文件出现编码问题
Jul 12 Python
Python 实现数据结构-循环队列的操作方法
Jul 17 Python
python文字转语音的实例代码分析
Nov 12 Python
python操作cfg配置文件方式
Dec 22 Python
python关于变量名的基础知识点
Mar 03 Python
pyinstaller打包单文件时--uac-admin选项不起作用怎么办
Apr 15 Python
如何快速理解python的垃圾回收机制
Sep 01 Python
利用Python中的Xpath实现一个在线汇率转换器
Sep 09 Python
Python命令行参数定义及需要注意的地方
Nov 30 Python
pytorch查看模型weight与grad方式
Jun 24 #Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 #Python
可视化pytorch 模型中不同BN层的running mean曲线实例
Jun 24 #Python
python3.x中安装web.py步骤方法
Jun 23 #Python
python如何删除文件、目录
Jun 23 #Python
TensorFlow保存TensorBoard图像操作
Jun 23 #Python
python和js交互调用的方法
Jun 23 #Python
You might like
Zend 输出产生XML解析错误
2009/03/03 PHP
php visitFile()遍历指定文件夹函数
2010/08/21 PHP
一个php短网址的生成代码(仿微博短网址)
2014/05/07 PHP
PHP高级编程实例:编写守护进程
2014/09/02 PHP
PHP中Closure类的使用方法及详解
2015/10/09 PHP
PHP下使用mysqli的函数连接mysql出现warning: mysqli::real_connect(): (hy000/1040): ...
2016/02/14 PHP
php常用字符函数实例小结
2016/12/29 PHP
Yii2框架中使用PHPExcel导出Excel文件的示例
2017/08/09 PHP
thinkphp集成前端脚手架Vue-cli的教程图解
2018/08/30 PHP
javascript脚本调试方法小结
2008/11/24 Javascript
js 幻灯片的实现
2011/12/06 Javascript
使用jquery实现简单的ajax
2013/07/08 Javascript
jquery中event对象属性与方法小结
2013/12/18 Javascript
Nodejs极简入门教程(三):进程
2014/10/27 NodeJs
RGB和YUV 多媒体编程基础详细介绍
2016/11/04 Javascript
浅谈基于Vue.js的移动组件库cube-ui
2017/12/20 Javascript
jQuery实现移动端图片上传预览组件的方法分析
2020/05/01 jQuery
python人人网登录应用实例
2014/09/26 Python
Python中创建字典的几种方法总结(推荐)
2017/04/27 Python
使用Python OpenCV为CNN增加图像样本的实现
2019/06/10 Python
Python数据库小程序源代码
2019/09/15 Python
使用Python的networkx绘制精美网络图教程
2019/11/21 Python
pytorch .detach() .detach_() 和 .data用于切断反向传播的实现
2019/12/27 Python
Python3列表List入门知识附实例
2020/02/09 Python
python中matplotlib实现随鼠标滑动自动标注代码
2020/04/23 Python
Pycharm插件(Grep Console)自定义规则输出颜色日志的方法
2020/05/27 Python
Python读取pdf表格写入excel的方法
2021/01/22 Python
全球性的女装店:storets
2019/06/12 全球购物
导游实习生自荐书
2014/01/28 职场文书
民主生活会对照检查材料范文
2014/10/01 职场文书
2015大学生党员自我评价范文
2015/03/03 职场文书
语文教师求职信范文
2015/03/20 职场文书
平凡的世界读书笔记
2015/06/25 职场文书
MySQL Shell的介绍以及安装
2021/04/24 MySQL
vue实现列表拖拽排序的示例代码
2022/04/08 Vue.js
vue配置型表格基于el-table拓展之table-plus组件
2022/04/12 Vue.js