pytorch中的weight-initilzation用法


Posted in Python onJune 24, 2020

pytorch中的权值初始化

官方论坛对weight-initilzation的讨论

torch.nn.Module.apply(fn)

torch.nn.Module.apply(fn)
# 递归的调用weights_init函数,遍历nn.Module的submodule作为参数
# 常用来对模型的参数进行初始化
# fn是对参数进行初始化的函数的句柄,fn以nn.Module或者自己定义的nn.Module的子类作为参数
# fn (Module -> None) ? function to be applied to each submodule
# Returns: self
# Return type: Module

例子:

def weights_init(m):
 classname = m.__class__.__name__
 if classname.find('Conv') != -1:
  m.weight.data.normal_(0.0, 0.02) 
  # m.weight.data是卷积核参数, m.bias.data是偏置项参数
 elif classname.find('BatchNorm') != -1:
  m.weight.data.normal_(1.0, 0.02)
  m.bias.data.fill_(0)

netG = _netG(ngpu) # 生成模型实例
netG.apply(weights_init) # 递归的调用weights_init函数,遍历netG的submodule作为参数
#-*-coding:utf-8-*-
import torch
from torch.autograd import Variable

# 对模型参数进行初始化
# 官方论坛链接:https://discuss.pytorch.org/t/weight-initilzation/157/3

# 方法一
# 单独定义一个weights_init函数,输入参数是m(torch.nn.module或者自己定义的继承nn.module的子类)
# 然后使用net.apply()进行参数初始化
# m.__class__.__name__ 获得nn.module的名字
# https://github.com/pytorch/examples/blob/master/dcgan/main.py#L90-L96
def weights_init(m):
 classname = m.__class__.__name__
 if classname.find('Conv') != -1:
  m.weight.data.normal_(0.0, 0.02)
 elif classname.find('BatchNorm') != -1:
  m.weight.data.normal_(1.0, 0.02)
  m.bias.data.fill_(0)

netG = _netG(ngpu) # 生成模型实例
netG.apply(weights_init) # 递归的调用weights_init函数,遍历netG的submodule作为参数

# function to be applied to each submodule

# 方法二
# 1. 使用net.modules()遍历模型中的网络层的类型 2. 对其中的m层的weigth.data(tensor)部分进行初始化操作
# Another initialization example from PyTorch Vision resnet implementation.
# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L112-L118
class ResNet(nn.Module):
 def __init__(self, block, layers, num_classes=1000):
  self.inplanes = 64
  super(ResNet, self).__init__()
  self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
        bias=False)
  self.bn1 = nn.BatchNorm2d(64)
  self.relu = nn.ReLU(inplace=True)
  self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  self.layer1 = self._make_layer(block, 64, layers[0])
  self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
  self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
  self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
  self.avgpool = nn.AvgPool2d(7, stride=1)
  self.fc = nn.Linear(512 * block.expansion, num_classes)
  # 权值参数初始化
  for m in self.modules():
   if isinstance(m, nn.Conv2d):
    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
    m.weight.data.normal_(0, math.sqrt(2. / n))
   elif isinstance(m, nn.BatchNorm2d):
    m.weight.data.fill_(1)
    m.bias.data.zero_()

# 方法三
# 自己知道网络中参数的顺序和类型, 然后将参数依次读取出来,调用torch.nn.init中的方法进行初始化
net = AlexNet(2)
params = list(net.parameters()) # params依次为Conv2d参数和Bias参数
# 或者
conv1Params = list(net.conv1.parameters())
# 其中,conv1Params[0]表示卷积核参数, conv1Params[1]表示bias项参数
# 然后使用torch.nn.init中函数进行初始化
torch.nn.init.normal(tensor, mean=0, std=1)
torch.nn.init.constant(tensor, 0)

# net.modules()迭代的返回: AlexNet,Sequential,Conv2d,ReLU,MaxPool2d,LRN,AvgPool3d....,Conv2d,...,Conv2d,...,Linear,
# 这里,只有Conv2d和Linear才有参数
# net.children()只返回实际存在的子模块: Sequential,Sequential,Sequential,Sequential,Sequential,Sequential,Sequential,Linear

# 附AlexNet的定义
class AlexNet(nn.Module):
 def __init__(self, num_classes = 2): # 默认为两类,猫和狗
#   super().__init__() # python3
  super(AlexNet, self).__init__()
  # 开始构建AlexNet网络模型,5层卷积,3层全连接层
  # 5层卷积层
  self.conv1 = nn.Sequential(
   nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4),
   nn.ReLU(inplace=True),
   nn.MaxPool2d(kernel_size=3, stride=2),
   LRN(local_size=5, bias=1, alpha=1e-4, beta=0.75, ACROSS_CHANNELS=True)
  )
  self.conv2 = nn.Sequential(
   nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, groups=2, padding=2),
   nn.ReLU(inplace=True),
   nn.MaxPool2d(kernel_size=3, stride=2),
   LRN(local_size=5, bias=1, alpha=1e-4, beta=0.75, ACROSS_CHANNELS=True)
  )
  self.conv3 = nn.Sequential(
   nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1),
   nn.ReLU(inplace=True)
  )
  self.conv4 = nn.Sequential(
   nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1),
   nn.ReLU(inplace=True)
  )
  self.conv5 = nn.Sequential(
   nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1),
   nn.ReLU(inplace=True),
   nn.MaxPool2d(kernel_size=3, stride=2)
  )
  # 3层全连接层
  # 前向计算的时候,最开始输入需要进行view操作,将3D的tensor变为1D
  self.fc6 = nn.Sequential(
   nn.Linear(in_features=6*6*256, out_features=4096),
   nn.ReLU(inplace=True),
   nn.Dropout()
  )
  self.fc7 = nn.Sequential(
   nn.Linear(in_features=4096, out_features=4096),
   nn.ReLU(inplace=True),
   nn.Dropout()
  )
  self.fc8 = nn.Linear(in_features=4096, out_features=num_classes)

 def forward(self, x):
  x = self.conv5(self.conv4(self.conv3(self.conv2(self.conv1(x)))))
  x = x.view(-1, 6*6*256)
  x = self.fc8(self.fc7(self.fc6(x)))
  return x

补充知识:pytorch Load部分weights

我们从网上down下来的模型与我们的模型可能就存在一个层的差异,此时我们就需要重新训练所有的参数是不合理的。

因此我们可以加载相同的参数,而忽略不同的参数,代码如下:

pretrained_dict = torch.load(“model.pth”)
  model_dict = et.state_dict()
  pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
  model_dict.update(pretrained_dict)
  net.load_state_dict(model_dict)

以上这篇pytorch中的weight-initilzation用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python模拟登陆Tom邮箱示例分享
Jan 13 Python
windows系统中python使用rar命令压缩多个文件夹示例
May 06 Python
Django处理文件上传File Uploads的实例
May 28 Python
pycharm的console输入实现换行的方法
Jan 16 Python
深入解析python中的实例方法、类方法和静态方法
Mar 11 Python
python实现可变变量名方法详解
Jul 01 Python
详解Python利用random生成一个列表内的随机数
Aug 21 Python
使用Python进行中文繁简转换的实现代码
Oct 18 Python
Python编程快速上手——Excel表格创建乘法表案例分析
Feb 28 Python
Python如何通过百度翻译API实现翻译功能
Apr 02 Python
详解pycharm的python包opencv(cv2)无代码提示问题的解决
Jan 29 Python
教你使用Pandas直接核算Excel中快递费用
May 12 Python
pytorch查看模型weight与grad方式
Jun 24 #Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 #Python
可视化pytorch 模型中不同BN层的running mean曲线实例
Jun 24 #Python
python3.x中安装web.py步骤方法
Jun 23 #Python
python如何删除文件、目录
Jun 23 #Python
TensorFlow保存TensorBoard图像操作
Jun 23 #Python
python和js交互调用的方法
Jun 23 #Python
You might like
php 三维饼图的实现代码
2008/09/28 PHP
php 目录与文件处理-郑阿奇(续)
2011/07/04 PHP
在windows服务器开启php的gd库phpinfo中未发现
2013/01/13 PHP
浅析PHP安装扩展mcrypt以及相关依赖项(PHP安装PECL扩展的方法)
2013/07/05 PHP
从wamp到xampp的升级之路
2015/04/08 PHP
PHP实现将几张照片拼接到一起的合成图片功能【便于整体打印输出】
2017/11/14 PHP
javascript笔试题目附答案@20081025_jb51.net
2008/10/26 Javascript
基于jquery的用鼠标画出可移动的div
2012/09/06 Javascript
常用的JS验证和函数汇总
2014/12/23 Javascript
jQuery中[attribute*=value]选择器用法实例
2014/12/31 Javascript
javascript实现控制文字大中小显示
2015/04/28 Javascript
js简单倒计时实现代码
2016/04/30 Javascript
浅谈angularJS中的事件
2016/07/12 Javascript
JS实现隐藏同级元素后只显示JS文件内容的方法
2016/09/04 Javascript
AngularJS模仿Form表单提交的实现代码
2016/12/08 Javascript
ES6新特性之模块Module用法详解
2017/04/01 Javascript
微信小程序自定义模态对话框实例详解
2017/08/16 Javascript
PHP 实现一种多文件上传的方法
2017/09/20 Javascript
Vue Router的懒加载路径的解决方法
2018/06/21 Javascript
使用layer弹窗和layui表单实现新增功能
2018/08/09 Javascript
JavaScript 点击触发复制功能实例详解
2018/11/02 Javascript
微信小程序判断页面是否从其他页面返回的实例代码
2019/07/03 Javascript
electron-vue开发环境内存泄漏问题汇总
2019/10/10 Javascript
JS实现横向轮播图(中级版)
2020/01/18 Javascript
举例详解Python中循环语句的嵌套使用
2015/05/14 Python
Python用zip函数同时遍历多个迭代器示例详解
2016/11/14 Python
Python中的is和==比较两个对象的两种方法
2017/09/06 Python
python3 读写文件换行符的方法
2018/04/09 Python
Django框架视图函数设计示例
2019/07/29 Python
Django 限制访问频率的思路详解
2019/12/24 Python
仿CSDN Blog返回页面顶部功能实现原理及代码
2013/06/30 HTML / CSS
简历中自我评价分享
2013/10/09 职场文书
领导干部失职检讨书
2015/05/05 职场文书
高三化学教学反思
2016/02/22 职场文书
Redis高并发防止秒杀超卖实战源码解决方案
2021/11/01 Redis
Redis高并发缓存架构性能优化
2022/05/15 Redis