Pytorch 实现权重初始化


Posted in Python onDecember 31, 2019

在TensorFlow中,权重的初始化主要是在声明张量的时候进行的。 而PyTorch则提供了另一种方法:首先应该声明张量,然后修改张量的权重。通过调用torch.nn.init包中的多种方法可以将权重初始化为直接访问张量的属性。

1、不初始化的效果

在Pytorch中,定义一个tensor,不进行初始化,打印看看结果:

w = torch.Tensor(3,4)
print (w)

可以看到这时候的初始化的数值都是随机的,而且特别大,这对网络的训练必定不好,最后导致精度提不上,甚至损失无法收敛。

2、初始化的效果

PyTorch提供了多种参数初始化函数:

torch.nn.init.constant(tensor, val)
torch.nn.init.normal(tensor, mean=0, std=1)
torch.nn.init.xavier_uniform(tensor, gain=1)

等等。详细请参考:http://pytorch.org/docs/nn.html#torch-nn-init

注意上面的初始化函数的参数tensor,虽然写的是tensor,但是也可以是Variable类型的。而神经网络的参数类型Parameter是Variable类的子类,所以初始化函数可以直接作用于神经网络参数。实际上,我们初始化也是直接去初始化神经网络的参数。

让我们试试效果:

w = torch.Tensor(3,4)
torch.nn.init.normal_(w)
print (w)

3、初始化神经网络的参数

对神经网络的初始化往往放在模型的__init__()函数中,如下所示:

class Net(nn.Module):

def __init__(self, block, layers, num_classes=1000):
  self.inplanes = 64
  super(Net, self).__init__()
  ***
  *** #定义自己的网络层
  ***

  for m in self.modules():
    if isinstance(m, nn.Conv2d):
      n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
      m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
      m.weight.data.fill_(1)
      m.bias.data.zero_()

***
*** #定义后续的函数
***

也可以采取另一种方式:

定义一个权重初始化函数,如下:

def weights_init(m):
  classname = m.__class__.__name__
  if classname.find('Conv2d') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)
  elif classname.find('Linear') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)

在模型声明时,调用初始化函数,初始化神经网络参数:

model = Net(*****)
model.apply(weights_init)

以上这篇Pytorch 实现权重初始化就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python遍历目录并批量更换文件名和目录名的方法
Sep 19 Python
Ubuntu下创建虚拟独立的Python环境全过程
Feb 10 Python
sublime python3 输入换行不结束的方法
Apr 19 Python
python 获取指定文件夹下所有文件名称并写入列表的实例
Apr 23 Python
python数据批量写入ScrolledText的优化方法
Oct 11 Python
python中for循环输出列表索引与对应的值方法
Nov 07 Python
Python pandas DataFrame操作的实现代码
Jun 21 Python
django创建最简单HTML页面跳转方法
Aug 16 Python
pandas factorize实现将字符串特征转化为数字特征
Dec 19 Python
Java多线程实现四种方式原理详解
Jun 02 Python
Python实现查询剪贴板自动匹配信息的思路详解
Jul 09 Python
Python3使用Qt5来实现简易的五子棋小游戏
May 02 Python
pytorch 归一化与反归一化实例
Dec 31 #Python
Pytorch 数据加载与数据预处理方式
Dec 31 #Python
pytorch 数据处理:定义自己的数据集合实例
Dec 31 #Python
pytorch: Parameter 的数据结构实例
Dec 31 #Python
Python测试线程应用程序过程解析
Dec 31 #Python
Python TCPServer 多线程多客户端通信的实现
Dec 31 #Python
python给指定csv表格中的联系人群发邮件(带附件的邮件)
Dec 31 #Python
You might like
PHP中的函数-- foreach()的用法详解
2013/06/24 PHP
php表单请求获得数据求和示例
2014/05/15 PHP
PHP中echo和print的区别
2014/08/28 PHP
php 魔术方法详解
2014/11/11 PHP
event.srcElement+表格应用
2006/08/29 Javascript
js操作时间(年-月-日 时-分-秒 星期几)
2010/06/20 Javascript
scrollWidth,clientWidth,offsetWidth的区别
2015/01/13 Javascript
JQUERY简单按钮轮换选中效果实现方法
2015/05/07 Javascript
windows下安装nodejs及框架express
2015/08/07 NodeJs
Highcharts入门之简介
2016/08/02 Javascript
js实现浏览器倒计时跳转页面效果
2016/08/12 Javascript
EasyUI创建对话框的两种方式
2016/08/23 Javascript
ES6使用let命令更简单的实现块级作用域实例分析
2017/03/31 Javascript
jQuery自定义元素右键点击事件(实现案例)
2017/04/28 jQuery
node-sass安装失败的原因与解决方法
2017/09/04 Javascript
Node.js中你不可不精的Stream(流)
2018/06/08 Javascript
前端插件之Bootstrap Dual Listbox使用教程
2019/07/23 Javascript
[47:02]2018DOTA2亚洲邀请赛3月29日 小组赛B组 VP VS paiN
2018/03/30 DOTA
[01:31:22]DOTA2-DPC中国联赛定级赛 LBZS vs Magma BO3第二场 1月10日
2021/03/11 DOTA
Python中的默认参数详解
2015/06/24 Python
Python 爬虫学习笔记之正则表达式
2016/09/21 Python
python 日志模块 日志等级设置失效的解决方案
2020/05/26 Python
python和php学习哪个更有发展
2020/06/17 Python
Python ckeditor富文本编辑器代码实例解析
2020/06/22 Python
python性能测试工具locust的使用
2020/12/28 Python
html5录音功能实战示例
2019/03/25 HTML / CSS
Chicco婴儿用品美国官网:汽车座椅、婴儿推车、高脚椅等
2018/11/05 全球购物
Cinque网上商店:德国服装品牌
2019/03/17 全球购物
Linux开机引导的步骤是什么
2014/02/26 面试题
客户经理岗位职责
2013/12/08 职场文书
中学生爱国演讲稿
2014/09/05 职场文书
软环境建设心得体会
2014/09/09 职场文书
小学趣味运动会加油稿
2014/09/25 职场文书
学历证明范文
2015/06/16 职场文书
网络新闻该怎么写?这些写作技巧你都知道吗?
2019/08/26 职场文书
Golang 使用Map实现去重与set的功能操作
2021/04/29 Golang