Pytorch 实现权重初始化


Posted in Python onDecember 31, 2019

在TensorFlow中,权重的初始化主要是在声明张量的时候进行的。 而PyTorch则提供了另一种方法:首先应该声明张量,然后修改张量的权重。通过调用torch.nn.init包中的多种方法可以将权重初始化为直接访问张量的属性。

1、不初始化的效果

在Pytorch中,定义一个tensor,不进行初始化,打印看看结果:

w = torch.Tensor(3,4)
print (w)

可以看到这时候的初始化的数值都是随机的,而且特别大,这对网络的训练必定不好,最后导致精度提不上,甚至损失无法收敛。

2、初始化的效果

PyTorch提供了多种参数初始化函数:

torch.nn.init.constant(tensor, val)
torch.nn.init.normal(tensor, mean=0, std=1)
torch.nn.init.xavier_uniform(tensor, gain=1)

等等。详细请参考:http://pytorch.org/docs/nn.html#torch-nn-init

注意上面的初始化函数的参数tensor,虽然写的是tensor,但是也可以是Variable类型的。而神经网络的参数类型Parameter是Variable类的子类,所以初始化函数可以直接作用于神经网络参数。实际上,我们初始化也是直接去初始化神经网络的参数。

让我们试试效果:

w = torch.Tensor(3,4)
torch.nn.init.normal_(w)
print (w)

3、初始化神经网络的参数

对神经网络的初始化往往放在模型的__init__()函数中,如下所示:

class Net(nn.Module):

def __init__(self, block, layers, num_classes=1000):
  self.inplanes = 64
  super(Net, self).__init__()
  ***
  *** #定义自己的网络层
  ***

  for m in self.modules():
    if isinstance(m, nn.Conv2d):
      n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
      m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
      m.weight.data.fill_(1)
      m.bias.data.zero_()

***
*** #定义后续的函数
***

也可以采取另一种方式:

定义一个权重初始化函数,如下:

def weights_init(m):
  classname = m.__class__.__name__
  if classname.find('Conv2d') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)
  elif classname.find('Linear') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)

在模型声明时,调用初始化函数,初始化神经网络参数:

model = Net(*****)
model.apply(weights_init)

以上这篇Pytorch 实现权重初始化就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python文件操作类操作实例详解
Jul 11 Python
Python下使用Psyco模块优化运行速度
Apr 05 Python
Python中转换角度为弧度的radians()方法
May 18 Python
Python随机数random模块使用指南
Sep 09 Python
Python排序搜索基本算法之选择排序实例分析
Dec 09 Python
Python基于列表模拟堆栈和队列功能示例
Jan 05 Python
Python断言assert的用法代码解析
Feb 03 Python
python3+PyQt5+Qt Designer实现堆叠窗口部件
Apr 20 Python
python使用Matplotlib画饼图
Sep 25 Python
基于MATLAB和Python实现MFCC特征参数提取
Aug 13 Python
python集成开发环境配置(pycharm)
Feb 14 Python
详解Python流程控制语句
Oct 28 Python
pytorch 归一化与反归一化实例
Dec 31 #Python
Pytorch 数据加载与数据预处理方式
Dec 31 #Python
pytorch 数据处理:定义自己的数据集合实例
Dec 31 #Python
pytorch: Parameter 的数据结构实例
Dec 31 #Python
Python测试线程应用程序过程解析
Dec 31 #Python
Python TCPServer 多线程多客户端通信的实现
Dec 31 #Python
python给指定csv表格中的联系人群发邮件(带附件的邮件)
Dec 31 #Python
You might like
php记录代码执行时间(实现代码)
2013/07/05 PHP
Gambit vs CL BO3 第三场 2.13
2021/03/10 DOTA
参考:关于Javascript中实现暂停的几篇文章
2007/03/04 Javascript
js使用for循环及if语句判断多个一样的name
2014/09/09 Javascript
Jquery1.9.1源码分析系列(十五)动画处理之外篇
2015/12/04 Javascript
关于安卓手机微信浏览器中使用XMLHttpRequest 2上传图片显示字节数为0的解决办法
2016/05/17 Javascript
jQuery插件实现图片轮播特效
2016/06/16 Javascript
JavaScript常用代码书写规范的超全面总结
2016/09/11 Javascript
前端框架Vue.js中Directive知识详解
2016/09/12 Javascript
JS正则替换掉小括号及内容的方法
2016/11/29 Javascript
node.js实现登录注册页面
2017/04/08 Javascript
BootStrap实现文件上传并带有进度条效果
2017/09/11 Javascript
JS非行间样式获取函数的实例代码
2018/06/05 Javascript
Vue源码学习之关于对Array的数据侦听实现
2019/04/23 Javascript
简单了解JavaScript中常见的反模式
2019/06/21 Javascript
jQuery高级编程之js对象、json与ajax用法实例分析
2019/11/01 jQuery
Vue实现剪切板图片压缩功能
2020/02/04 Javascript
小程序实现简单语音聊天的示例代码
2020/07/24 Javascript
解决vue动态下拉菜单 有数据未反应的问题
2020/08/06 Javascript
[00:14]PWL:老朋友Mushi拍VLOG与中国玩家问好
2020/11/04 DOTA
python机器学习之神经网络(二)
2017/12/20 Python
Python使用pyodbc访问数据库操作方法详解
2018/07/05 Python
Python中的四种交换数值的方法解析
2019/11/18 Python
python找出列表中大于某个阈值的数据段示例
2019/11/24 Python
Django中FilePathField字段的用法
2020/05/21 Python
Python性能测试工具Locust安装及使用
2020/12/01 Python
html5 canvas绘制矩形和圆形的实例代码
2016/06/16 HTML / CSS
希尔顿酒店官方网站:Hilton Hotels
2017/06/01 全球购物
摄影实习自我鉴定
2013/09/20 职场文书
机关办公室岗位职责
2014/04/16 职场文书
学校食堂食品安全责任书
2014/07/28 职场文书
优秀党员自我评价范文
2014/09/15 职场文书
毕业典礼致辞
2015/07/29 职场文书
原来实习报告是这样写的呀!
2019/07/03 职场文书
OpenCV-Python实现油画效果的实例
2021/06/08 Python
Java版 单机五子棋
2022/05/04 Java/Android