pytorch中的weight-initilzation用法


Posted in Python onJune 24, 2020

pytorch中的权值初始化

官方论坛对weight-initilzation的讨论

torch.nn.Module.apply(fn)

torch.nn.Module.apply(fn)
# 递归的调用weights_init函数,遍历nn.Module的submodule作为参数
# 常用来对模型的参数进行初始化
# fn是对参数进行初始化的函数的句柄,fn以nn.Module或者自己定义的nn.Module的子类作为参数
# fn (Module -> None) ? function to be applied to each submodule
# Returns: self
# Return type: Module

例子:

def weights_init(m):
 classname = m.__class__.__name__
 if classname.find('Conv') != -1:
  m.weight.data.normal_(0.0, 0.02) 
  # m.weight.data是卷积核参数, m.bias.data是偏置项参数
 elif classname.find('BatchNorm') != -1:
  m.weight.data.normal_(1.0, 0.02)
  m.bias.data.fill_(0)

netG = _netG(ngpu) # 生成模型实例
netG.apply(weights_init) # 递归的调用weights_init函数,遍历netG的submodule作为参数
#-*-coding:utf-8-*-
import torch
from torch.autograd import Variable

# 对模型参数进行初始化
# 官方论坛链接:https://discuss.pytorch.org/t/weight-initilzation/157/3

# 方法一
# 单独定义一个weights_init函数,输入参数是m(torch.nn.module或者自己定义的继承nn.module的子类)
# 然后使用net.apply()进行参数初始化
# m.__class__.__name__ 获得nn.module的名字
# https://github.com/pytorch/examples/blob/master/dcgan/main.py#L90-L96
def weights_init(m):
 classname = m.__class__.__name__
 if classname.find('Conv') != -1:
  m.weight.data.normal_(0.0, 0.02)
 elif classname.find('BatchNorm') != -1:
  m.weight.data.normal_(1.0, 0.02)
  m.bias.data.fill_(0)

netG = _netG(ngpu) # 生成模型实例
netG.apply(weights_init) # 递归的调用weights_init函数,遍历netG的submodule作为参数

# function to be applied to each submodule

# 方法二
# 1. 使用net.modules()遍历模型中的网络层的类型 2. 对其中的m层的weigth.data(tensor)部分进行初始化操作
# Another initialization example from PyTorch Vision resnet implementation.
# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L112-L118
class ResNet(nn.Module):
 def __init__(self, block, layers, num_classes=1000):
  self.inplanes = 64
  super(ResNet, self).__init__()
  self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
        bias=False)
  self.bn1 = nn.BatchNorm2d(64)
  self.relu = nn.ReLU(inplace=True)
  self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  self.layer1 = self._make_layer(block, 64, layers[0])
  self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
  self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
  self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
  self.avgpool = nn.AvgPool2d(7, stride=1)
  self.fc = nn.Linear(512 * block.expansion, num_classes)
  # 权值参数初始化
  for m in self.modules():
   if isinstance(m, nn.Conv2d):
    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
    m.weight.data.normal_(0, math.sqrt(2. / n))
   elif isinstance(m, nn.BatchNorm2d):
    m.weight.data.fill_(1)
    m.bias.data.zero_()

# 方法三
# 自己知道网络中参数的顺序和类型, 然后将参数依次读取出来,调用torch.nn.init中的方法进行初始化
net = AlexNet(2)
params = list(net.parameters()) # params依次为Conv2d参数和Bias参数
# 或者
conv1Params = list(net.conv1.parameters())
# 其中,conv1Params[0]表示卷积核参数, conv1Params[1]表示bias项参数
# 然后使用torch.nn.init中函数进行初始化
torch.nn.init.normal(tensor, mean=0, std=1)
torch.nn.init.constant(tensor, 0)

# net.modules()迭代的返回: AlexNet,Sequential,Conv2d,ReLU,MaxPool2d,LRN,AvgPool3d....,Conv2d,...,Conv2d,...,Linear,
# 这里,只有Conv2d和Linear才有参数
# net.children()只返回实际存在的子模块: Sequential,Sequential,Sequential,Sequential,Sequential,Sequential,Sequential,Linear

# 附AlexNet的定义
class AlexNet(nn.Module):
 def __init__(self, num_classes = 2): # 默认为两类,猫和狗
#   super().__init__() # python3
  super(AlexNet, self).__init__()
  # 开始构建AlexNet网络模型,5层卷积,3层全连接层
  # 5层卷积层
  self.conv1 = nn.Sequential(
   nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4),
   nn.ReLU(inplace=True),
   nn.MaxPool2d(kernel_size=3, stride=2),
   LRN(local_size=5, bias=1, alpha=1e-4, beta=0.75, ACROSS_CHANNELS=True)
  )
  self.conv2 = nn.Sequential(
   nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, groups=2, padding=2),
   nn.ReLU(inplace=True),
   nn.MaxPool2d(kernel_size=3, stride=2),
   LRN(local_size=5, bias=1, alpha=1e-4, beta=0.75, ACROSS_CHANNELS=True)
  )
  self.conv3 = nn.Sequential(
   nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1),
   nn.ReLU(inplace=True)
  )
  self.conv4 = nn.Sequential(
   nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1),
   nn.ReLU(inplace=True)
  )
  self.conv5 = nn.Sequential(
   nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1),
   nn.ReLU(inplace=True),
   nn.MaxPool2d(kernel_size=3, stride=2)
  )
  # 3层全连接层
  # 前向计算的时候,最开始输入需要进行view操作,将3D的tensor变为1D
  self.fc6 = nn.Sequential(
   nn.Linear(in_features=6*6*256, out_features=4096),
   nn.ReLU(inplace=True),
   nn.Dropout()
  )
  self.fc7 = nn.Sequential(
   nn.Linear(in_features=4096, out_features=4096),
   nn.ReLU(inplace=True),
   nn.Dropout()
  )
  self.fc8 = nn.Linear(in_features=4096, out_features=num_classes)

 def forward(self, x):
  x = self.conv5(self.conv4(self.conv3(self.conv2(self.conv1(x)))))
  x = x.view(-1, 6*6*256)
  x = self.fc8(self.fc7(self.fc6(x)))
  return x

补充知识:pytorch Load部分weights

我们从网上down下来的模型与我们的模型可能就存在一个层的差异,此时我们就需要重新训练所有的参数是不合理的。

因此我们可以加载相同的参数,而忽略不同的参数,代码如下:

pretrained_dict = torch.load(“model.pth”)
  model_dict = et.state_dict()
  pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
  model_dict.update(pretrained_dict)
  net.load_state_dict(model_dict)

以上这篇pytorch中的weight-initilzation用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python装饰器入门学习教程(九步学习)
Jan 28 Python
Python2.7基于淘宝接口获取IP地址所在地理位置的方法【测试可用】
Jun 07 Python
Python 中Pickle库的使用详解
Feb 24 Python
python检测主机的连通性并记录到文件的实例
Jun 21 Python
Python中的取模运算方法
Nov 10 Python
python实现大转盘抽奖效果
Jan 22 Python
python web框架Flask实现图形验证码及验证码的动态刷新实例
Oct 14 Python
如何基于线程池提升request模块效率
Apr 18 Python
Keras 使用 Lambda层详解
Jun 10 Python
详解Python3.8+PyQt5+pyqt5-tools+Pycharm配置详细教程
Nov 02 Python
python脚本框架webpy模板控制结构
Nov 20 Python
Python中使用tkFileDialog实现文件选择、保存和路径选择
May 20 Python
pytorch查看模型weight与grad方式
Jun 24 #Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 #Python
可视化pytorch 模型中不同BN层的running mean曲线实例
Jun 24 #Python
python3.x中安装web.py步骤方法
Jun 23 #Python
python如何删除文件、目录
Jun 23 #Python
TensorFlow保存TensorBoard图像操作
Jun 23 #Python
python和js交互调用的方法
Jun 23 #Python
You might like
剧场版动画《PSYCHO-PASS 3 FIRST INSPECTOR》3月27日日本上映!
2020/03/06 日漫
str_replace只替换一次字符串的方法
2013/04/09 PHP
关于PHP通用返回值设置方法
2017/03/31 PHP
php异常处理捕获错误整理
2019/09/23 PHP
express的中间件basicAuth详解
2014/12/04 Javascript
jquery使用经验小结
2015/05/20 Javascript
JavaScript中函数表达式和函数声明及函数声明与函数表达式的不同
2015/11/15 Javascript
AngularJS中的Directive自定义一个表格
2016/01/25 Javascript
js实现页面跳转的五种方法推荐
2016/03/10 Javascript
JavaScript位移运算符(无符号) >>> 三个大于号 的使用方法详解
2016/03/31 Javascript
js获取腾讯视频ID的方法
2016/10/03 Javascript
layer弹出层中H5播放器全屏出错的解决方法
2017/02/21 Javascript
深入理解nodejs中Express的中间件
2017/05/19 NodeJs
利用Vue.js实现求职在线之职位查询功能
2017/07/03 Javascript
微信小程序wx.getImageInfo()如何获取图片信息
2018/01/26 Javascript
还不懂递归?读完这篇文章保证你会懂
2018/07/29 Javascript
vue-cli3中vue.config.js配置教程详解
2019/05/29 Javascript
Nuxt使用Vuex的方法示例
2019/09/06 Javascript
关于layui toolbar和template的结合使用方法
2019/09/19 Javascript
vue 实现在同一界面实现组件的动态添加和删除功能
2020/06/16 Javascript
[01:43]3.19DOTA2发布会 三代刀塔人第三代
2014/03/25 DOTA
[33:33]完美世界DOTA2联赛PWL S2 FTD.C vs SZ 第二场 11.27
2020/11/30 DOTA
django session完成状态保持的方法
2018/11/27 Python
对Python中的条件判断、循环以及循环的终止方法详解
2019/02/08 Python
Python3.5 Pandas模块之Series用法实例分析
2019/04/23 Python
python在不同条件下的输入与输出
2020/02/13 Python
Python响应对象text属性乱码解决方案
2020/03/31 Python
python框架flask入门之路由及简单实现方法
2020/06/07 Python
python实现猜拳游戏项目
2020/11/30 Python
介绍一下SQL Server里面的索引视图
2016/07/31 面试题
Linux面试经常问的文件系统操作命令
2015/11/05 面试题
标准自荐信范文
2014/01/29 职场文书
房屋鉴定委托书范本
2014/09/23 职场文书
JavaScript 防篡改对象的用法示例
2021/04/24 Javascript
JavaScript展开运算符和剩余运算符的区别详解
2022/02/18 Javascript
JS class语法糖的深入剖析
2022/07/07 Javascript