pytorch中的weight-initilzation用法


Posted in Python onJune 24, 2020

pytorch中的权值初始化

官方论坛对weight-initilzation的讨论

torch.nn.Module.apply(fn)

torch.nn.Module.apply(fn)
# 递归的调用weights_init函数,遍历nn.Module的submodule作为参数
# 常用来对模型的参数进行初始化
# fn是对参数进行初始化的函数的句柄,fn以nn.Module或者自己定义的nn.Module的子类作为参数
# fn (Module -> None) ? function to be applied to each submodule
# Returns: self
# Return type: Module

例子:

def weights_init(m):
 classname = m.__class__.__name__
 if classname.find('Conv') != -1:
  m.weight.data.normal_(0.0, 0.02) 
  # m.weight.data是卷积核参数, m.bias.data是偏置项参数
 elif classname.find('BatchNorm') != -1:
  m.weight.data.normal_(1.0, 0.02)
  m.bias.data.fill_(0)

netG = _netG(ngpu) # 生成模型实例
netG.apply(weights_init) # 递归的调用weights_init函数,遍历netG的submodule作为参数
#-*-coding:utf-8-*-
import torch
from torch.autograd import Variable

# 对模型参数进行初始化
# 官方论坛链接:https://discuss.pytorch.org/t/weight-initilzation/157/3

# 方法一
# 单独定义一个weights_init函数,输入参数是m(torch.nn.module或者自己定义的继承nn.module的子类)
# 然后使用net.apply()进行参数初始化
# m.__class__.__name__ 获得nn.module的名字
# https://github.com/pytorch/examples/blob/master/dcgan/main.py#L90-L96
def weights_init(m):
 classname = m.__class__.__name__
 if classname.find('Conv') != -1:
  m.weight.data.normal_(0.0, 0.02)
 elif classname.find('BatchNorm') != -1:
  m.weight.data.normal_(1.0, 0.02)
  m.bias.data.fill_(0)

netG = _netG(ngpu) # 生成模型实例
netG.apply(weights_init) # 递归的调用weights_init函数,遍历netG的submodule作为参数

# function to be applied to each submodule

# 方法二
# 1. 使用net.modules()遍历模型中的网络层的类型 2. 对其中的m层的weigth.data(tensor)部分进行初始化操作
# Another initialization example from PyTorch Vision resnet implementation.
# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L112-L118
class ResNet(nn.Module):
 def __init__(self, block, layers, num_classes=1000):
  self.inplanes = 64
  super(ResNet, self).__init__()
  self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
        bias=False)
  self.bn1 = nn.BatchNorm2d(64)
  self.relu = nn.ReLU(inplace=True)
  self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  self.layer1 = self._make_layer(block, 64, layers[0])
  self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
  self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
  self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
  self.avgpool = nn.AvgPool2d(7, stride=1)
  self.fc = nn.Linear(512 * block.expansion, num_classes)
  # 权值参数初始化
  for m in self.modules():
   if isinstance(m, nn.Conv2d):
    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
    m.weight.data.normal_(0, math.sqrt(2. / n))
   elif isinstance(m, nn.BatchNorm2d):
    m.weight.data.fill_(1)
    m.bias.data.zero_()

# 方法三
# 自己知道网络中参数的顺序和类型, 然后将参数依次读取出来,调用torch.nn.init中的方法进行初始化
net = AlexNet(2)
params = list(net.parameters()) # params依次为Conv2d参数和Bias参数
# 或者
conv1Params = list(net.conv1.parameters())
# 其中,conv1Params[0]表示卷积核参数, conv1Params[1]表示bias项参数
# 然后使用torch.nn.init中函数进行初始化
torch.nn.init.normal(tensor, mean=0, std=1)
torch.nn.init.constant(tensor, 0)

# net.modules()迭代的返回: AlexNet,Sequential,Conv2d,ReLU,MaxPool2d,LRN,AvgPool3d....,Conv2d,...,Conv2d,...,Linear,
# 这里,只有Conv2d和Linear才有参数
# net.children()只返回实际存在的子模块: Sequential,Sequential,Sequential,Sequential,Sequential,Sequential,Sequential,Linear

# 附AlexNet的定义
class AlexNet(nn.Module):
 def __init__(self, num_classes = 2): # 默认为两类,猫和狗
#   super().__init__() # python3
  super(AlexNet, self).__init__()
  # 开始构建AlexNet网络模型,5层卷积,3层全连接层
  # 5层卷积层
  self.conv1 = nn.Sequential(
   nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4),
   nn.ReLU(inplace=True),
   nn.MaxPool2d(kernel_size=3, stride=2),
   LRN(local_size=5, bias=1, alpha=1e-4, beta=0.75, ACROSS_CHANNELS=True)
  )
  self.conv2 = nn.Sequential(
   nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, groups=2, padding=2),
   nn.ReLU(inplace=True),
   nn.MaxPool2d(kernel_size=3, stride=2),
   LRN(local_size=5, bias=1, alpha=1e-4, beta=0.75, ACROSS_CHANNELS=True)
  )
  self.conv3 = nn.Sequential(
   nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1),
   nn.ReLU(inplace=True)
  )
  self.conv4 = nn.Sequential(
   nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1),
   nn.ReLU(inplace=True)
  )
  self.conv5 = nn.Sequential(
   nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1),
   nn.ReLU(inplace=True),
   nn.MaxPool2d(kernel_size=3, stride=2)
  )
  # 3层全连接层
  # 前向计算的时候,最开始输入需要进行view操作,将3D的tensor变为1D
  self.fc6 = nn.Sequential(
   nn.Linear(in_features=6*6*256, out_features=4096),
   nn.ReLU(inplace=True),
   nn.Dropout()
  )
  self.fc7 = nn.Sequential(
   nn.Linear(in_features=4096, out_features=4096),
   nn.ReLU(inplace=True),
   nn.Dropout()
  )
  self.fc8 = nn.Linear(in_features=4096, out_features=num_classes)

 def forward(self, x):
  x = self.conv5(self.conv4(self.conv3(self.conv2(self.conv1(x)))))
  x = x.view(-1, 6*6*256)
  x = self.fc8(self.fc7(self.fc6(x)))
  return x

补充知识:pytorch Load部分weights

我们从网上down下来的模型与我们的模型可能就存在一个层的差异,此时我们就需要重新训练所有的参数是不合理的。

因此我们可以加载相同的参数,而忽略不同的参数,代码如下:

pretrained_dict = torch.load(“model.pth”)
  model_dict = et.state_dict()
  pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
  model_dict.update(pretrained_dict)
  net.load_state_dict(model_dict)

以上这篇pytorch中的weight-initilzation用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python Django做网页
Nov 04 Python
Python打印scrapy蜘蛛抓取树结构的方法
Apr 08 Python
Python使用multiprocessing实现一个最简单的分布式作业调度系统
Mar 14 Python
Python搭建HTTP服务器和FTP服务器
Mar 09 Python
同时安装Python2 & Python3 cmd下版本自由选择的方法
Dec 09 Python
python中正则表达式的使用方法
Feb 25 Python
Python 实现递归法解决迷宫问题的示例代码
Jan 12 Python
新版Pycharm中Matplotlib不会弹出独立的显示窗口的问题
Jun 02 Python
Python并发请求下限制QPS(每秒查询率)的实现代码
Jun 05 Python
python退出循环的方法
Jun 18 Python
如何用python绘制雷达图
Apr 24 Python
Python合并多张图片成PDF
Jun 09 Python
pytorch查看模型weight与grad方式
Jun 24 #Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 #Python
可视化pytorch 模型中不同BN层的running mean曲线实例
Jun 24 #Python
python3.x中安装web.py步骤方法
Jun 23 #Python
python如何删除文件、目录
Jun 23 #Python
TensorFlow保存TensorBoard图像操作
Jun 23 #Python
python和js交互调用的方法
Jun 23 #Python
You might like
在线短消息收发的程序,不用数据库
2006/10/09 PHP
PHP 常见郁闷问题答解
2006/11/25 PHP
php使用curl抓取qq空间的访客信息示例
2014/02/28 PHP
PHP获取网站中各文章的第一张图片的代码示例
2016/05/20 PHP
ThinkPHP 模板引擎使用详解
2017/05/07 PHP
JavaScript入门教程 Cookies
2009/01/31 Javascript
js 鼠标点击事件及其它捕获
2009/06/04 Javascript
Array.prototype 的泛型应用分析
2010/04/30 Javascript
jquery键盘事件介绍
2011/01/31 Javascript
jQuery实现菜单感应鼠标滑动动画效果的方法
2015/02/28 Javascript
javascript中一些util方法汇总
2015/06/10 Javascript
JavaScript实现的经典文件树菜单效果
2015/09/08 Javascript
jQuery+PHP实现可编辑表格字段内容并实时保存
2015/10/09 Javascript
vuejs在解析时出现闪烁的原因及防止闪烁的方法
2016/09/19 Javascript
微信小程序 页面跳转传参详解
2016/10/28 Javascript
JQuery 进入页面默认给已赋值的复选框打钩
2017/03/23 jQuery
Nodejs回调加超时限制两种实现方法
2017/06/09 NodeJs
JavaScript该如何学习 怎样轻松学习JavaScript
2017/06/12 Javascript
JavaScript折半查找(二分查找)算法原理与实现方法示例
2018/08/06 Javascript
浅谈vue引用静态资源需要注意的事项
2018/09/28 Javascript
小程序实现投票进度条
2019/11/20 Javascript
python虚拟环境 virtualenv的简单使用
2020/01/21 Javascript
jQuery使用ajax传递json对象到服务端及contentType的用法示例
2020/03/12 jQuery
JavaScript数组排序的六种常见算法总结
2020/08/18 Javascript
[14:24]Optic Gaming vs PSG LGD BO3
2018/06/07 DOTA
Python获取远程文件大小的函数代码分享
2014/05/13 Python
Python实现文件内容批量追加的方法示例
2017/08/29 Python
Python数据可视化处理库PyEcharts柱状图,饼图,线性图,词云图常用实例详解
2020/02/10 Python
Python ini文件常用操作方法解析
2020/04/26 Python
python time()的实例用法
2020/11/03 Python
女大学生自我鉴定
2013/12/09 职场文书
退伍老兵事迹材料
2014/01/31 职场文书
团队拓展活动总结
2014/08/27 职场文书
个人四风问题整改措施
2014/10/24 职场文书
好媳妇事迹材料
2014/12/24 职场文书
中秋节寄语2015
2015/03/24 职场文书