Pytorch 实现权重初始化


Posted in Python onDecember 31, 2019

在TensorFlow中,权重的初始化主要是在声明张量的时候进行的。 而PyTorch则提供了另一种方法:首先应该声明张量,然后修改张量的权重。通过调用torch.nn.init包中的多种方法可以将权重初始化为直接访问张量的属性。

1、不初始化的效果

在Pytorch中,定义一个tensor,不进行初始化,打印看看结果:

w = torch.Tensor(3,4)
print (w)

可以看到这时候的初始化的数值都是随机的,而且特别大,这对网络的训练必定不好,最后导致精度提不上,甚至损失无法收敛。

2、初始化的效果

PyTorch提供了多种参数初始化函数:

torch.nn.init.constant(tensor, val)
torch.nn.init.normal(tensor, mean=0, std=1)
torch.nn.init.xavier_uniform(tensor, gain=1)

等等。详细请参考:http://pytorch.org/docs/nn.html#torch-nn-init

注意上面的初始化函数的参数tensor,虽然写的是tensor,但是也可以是Variable类型的。而神经网络的参数类型Parameter是Variable类的子类,所以初始化函数可以直接作用于神经网络参数。实际上,我们初始化也是直接去初始化神经网络的参数。

让我们试试效果:

w = torch.Tensor(3,4)
torch.nn.init.normal_(w)
print (w)

3、初始化神经网络的参数

对神经网络的初始化往往放在模型的__init__()函数中,如下所示:

class Net(nn.Module):

def __init__(self, block, layers, num_classes=1000):
  self.inplanes = 64
  super(Net, self).__init__()
  ***
  *** #定义自己的网络层
  ***

  for m in self.modules():
    if isinstance(m, nn.Conv2d):
      n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
      m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
      m.weight.data.fill_(1)
      m.bias.data.zero_()

***
*** #定义后续的函数
***

也可以采取另一种方式:

定义一个权重初始化函数,如下:

def weights_init(m):
  classname = m.__class__.__name__
  if classname.find('Conv2d') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)
  elif classname.find('Linear') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)

在模型声明时,调用初始化函数,初始化神经网络参数:

model = Net(*****)
model.apply(weights_init)

以上这篇Pytorch 实现权重初始化就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python Socket编程入门教程
Jul 11 Python
Python中optparse模块使用浅析
Jan 01 Python
TF-IDF算法解析与Python实现方法详解
Nov 16 Python
Python实现返回数组中第i小元素的方法示例
Dec 04 Python
Python3实现的画图及加载图片动画效果示例
Jan 19 Python
python函数式编程学习之yield表达式形式详解
Mar 25 Python
用Python shell简化开发
Aug 08 Python
Python模拟自动存取款机的查询、存取款、修改密码等操作
Sep 02 Python
将python文件打包exe独立运行程序方法详解
Feb 12 Python
Python通过socketserver处理多个链接
Mar 18 Python
python把一个字符串切开的实例方法
Sep 27 Python
python实现测试工具(二)——简单的ui测试工具
Oct 19 Python
pytorch 归一化与反归一化实例
Dec 31 #Python
Pytorch 数据加载与数据预处理方式
Dec 31 #Python
pytorch 数据处理:定义自己的数据集合实例
Dec 31 #Python
pytorch: Parameter 的数据结构实例
Dec 31 #Python
Python测试线程应用程序过程解析
Dec 31 #Python
Python TCPServer 多线程多客户端通信的实现
Dec 31 #Python
python给指定csv表格中的联系人群发邮件(带附件的邮件)
Dec 31 #Python
You might like
星际争霸教主Flash的ID由来:你永远不会知道他之前的ID是www!
2019/01/18 星际争霸
destoon实现资讯信息前面调用它所属分类的方法
2014/07/15 PHP
PHP缓存集成库phpFastCache用法
2014/12/15 PHP
javascript 得到变量类型的函数
2010/05/19 Javascript
javascript中的继承实例代码
2011/04/27 Javascript
jquery remove方法应用详解
2012/11/22 Javascript
jquery的ajax请求全面了解
2013/03/20 Javascript
基于jQuery实现下拉框
2014/11/24 Javascript
jQuery插件slicebox实现3D动画图片轮播切换特效
2015/04/12 Javascript
jQuery实现点击后标记当前菜单位置(背景高亮菜单)效果
2015/08/22 Javascript
纯HTML5制作围住神经猫游戏-附源码下载
2015/08/23 Javascript
jQuery获取checkboxlist的value值的方法
2015/09/27 Javascript
jQuery实现的分子运动小球碰撞效果
2016/01/27 Javascript
基于AngularJS+HTML+Groovy实现登录功能
2016/02/17 Javascript
js 显示日期时间的实例(时间过一秒加1)
2017/10/25 Javascript
解决vue keep-alive 数据更新的问题
2018/09/21 Javascript
深入了解js原型模式
2019/05/30 Javascript
JavaScript冒泡算法原理与实现方法深入理解
2020/06/04 Javascript
写了个监控nginx进程的Python脚本
2012/05/10 Python
Python中用函数作为返回值和实现闭包的教程
2015/04/27 Python
使用Python制作获取网站目录的图形化程序
2015/05/04 Python
Python使用内置json模块解析json格式数据的方法
2017/07/20 Python
Django Channels 实现点对点实时聊天和消息推送功能
2019/07/17 Python
Sofft鞋官网:世界知名鞋类品牌
2017/03/28 全球购物
全球立体声:World Wide Stereo
2018/09/29 全球购物
印度在线购买电子产品网站:Croma
2020/01/02 全球购物
介绍一下你对SOA的认识
2016/04/24 面试题
Ajax实现页面无刷新留言效果
2021/03/24 Javascript
财务部副经理岗位职责
2014/03/14 职场文书
产品售后服务承诺书
2014/05/21 职场文书
保护环境的标语
2014/06/09 职场文书
四风问题个人对照检查材料
2014/09/26 职场文书
2015年全国助残日活动方案
2015/05/04 职场文书
投资申请报告
2015/05/19 职场文书
2015年财政局工作总结
2015/05/21 职场文书
2016领导干部廉洁从政心得体会
2016/01/19 职场文书