Pytorch 实现权重初始化


Posted in Python onDecember 31, 2019

在TensorFlow中,权重的初始化主要是在声明张量的时候进行的。 而PyTorch则提供了另一种方法:首先应该声明张量,然后修改张量的权重。通过调用torch.nn.init包中的多种方法可以将权重初始化为直接访问张量的属性。

1、不初始化的效果

在Pytorch中,定义一个tensor,不进行初始化,打印看看结果:

w = torch.Tensor(3,4)
print (w)

可以看到这时候的初始化的数值都是随机的,而且特别大,这对网络的训练必定不好,最后导致精度提不上,甚至损失无法收敛。

2、初始化的效果

PyTorch提供了多种参数初始化函数:

torch.nn.init.constant(tensor, val)
torch.nn.init.normal(tensor, mean=0, std=1)
torch.nn.init.xavier_uniform(tensor, gain=1)

等等。详细请参考:http://pytorch.org/docs/nn.html#torch-nn-init

注意上面的初始化函数的参数tensor,虽然写的是tensor,但是也可以是Variable类型的。而神经网络的参数类型Parameter是Variable类的子类,所以初始化函数可以直接作用于神经网络参数。实际上,我们初始化也是直接去初始化神经网络的参数。

让我们试试效果:

w = torch.Tensor(3,4)
torch.nn.init.normal_(w)
print (w)

3、初始化神经网络的参数

对神经网络的初始化往往放在模型的__init__()函数中,如下所示:

class Net(nn.Module):

def __init__(self, block, layers, num_classes=1000):
  self.inplanes = 64
  super(Net, self).__init__()
  ***
  *** #定义自己的网络层
  ***

  for m in self.modules():
    if isinstance(m, nn.Conv2d):
      n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
      m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
      m.weight.data.fill_(1)
      m.bias.data.zero_()

***
*** #定义后续的函数
***

也可以采取另一种方式:

定义一个权重初始化函数,如下:

def weights_init(m):
  classname = m.__class__.__name__
  if classname.find('Conv2d') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)
  elif classname.find('Linear') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)

在模型声明时,调用初始化函数,初始化神经网络参数:

model = Net(*****)
model.apply(weights_init)

以上这篇Pytorch 实现权重初始化就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 读取.csv文件数据到数组(矩阵)的实例讲解
Jun 14 Python
python的中异常处理机制
Aug 30 Python
对pandas中iloc,loc取数据差别及按条件取值的方法详解
Nov 06 Python
Python函数装饰器实现方法详解
Dec 22 Python
python实现批量文件重命名
Oct 31 Python
Python 调用有道翻译接口实现翻译
Mar 02 Python
Python 读取WAV音频文件 画频谱的实例
Mar 14 Python
基于plt.title无法显示中文的快速解决
May 16 Python
Python用K-means聚类算法进行客户分群的实现
Aug 23 Python
Python绘制组合图的示例
Sep 18 Python
scrapy在python爬虫中搭建出错的解决方法
Nov 22 Python
解决pycharm导入numpy包的和使用时报错:RuntimeError: The current Numpy installation (‘D:\\python3.6\\lib\\site-packa的问题
Dec 08 Python
pytorch 归一化与反归一化实例
Dec 31 #Python
Pytorch 数据加载与数据预处理方式
Dec 31 #Python
pytorch 数据处理:定义自己的数据集合实例
Dec 31 #Python
pytorch: Parameter 的数据结构实例
Dec 31 #Python
Python测试线程应用程序过程解析
Dec 31 #Python
Python TCPServer 多线程多客户端通信的实现
Dec 31 #Python
python给指定csv表格中的联系人群发邮件(带附件的邮件)
Dec 31 #Python
You might like
php 信息采集程序代码
2009/03/17 PHP
php select,radio和checkbox默认选择的实现方法
2010/05/15 PHP
对于PHP 5.4 你必须要知道的
2013/08/07 PHP
php自动更新版权信息显示的方法
2015/06/19 PHP
php pdo oracle中文乱码的快速解决方法
2016/05/16 PHP
限制文本字节数js代码
2007/03/06 Javascript
jquery 操作两个select实现值之间的互相传递
2014/03/07 Javascript
JavaScript打印网页指定区域的例子
2014/05/03 Javascript
js实现网页自动刷新可制作节日倒计时效果
2014/05/27 Javascript
SuperSlide标签切换、焦点图多种组合插件
2015/03/14 Javascript
超赞的动手创建JavaScript框架的详细教程
2015/06/30 Javascript
window.location.hash知识汇总
2015/11/09 Javascript
JavaScript中SetInterval与setTimeout的用法详解
2015/11/10 Javascript
JavaScript编程学习技巧汇总
2016/02/21 Javascript
jQuery通用的全局遍历方法$.each()用法实例
2016/07/04 Javascript
JS仿hao123导航页面图片轮播效果
2016/09/01 Javascript
jQuery属性选择器用法示例
2016/09/09 Javascript
Vue 实现双向绑定的四种方法
2018/03/16 Javascript
vue中使用better-scroll实现滑动效果及注意事项
2018/11/15 Javascript
微信小程序 生成携带参数的二维码
2019/10/23 Javascript
jQuery冲突问题解决方法
2021/01/19 jQuery
在Python中用get()方法获取字典键值的教程
2015/05/21 Python
Python如何实现MySQL实例初始化详解
2017/11/06 Python
Python生成任意范围任意精度的随机数方法
2018/04/09 Python
pytorch 自定义数据集加载方法
2019/08/18 Python
Python 使用Opencv实现目标检测与识别的示例代码
2020/09/08 Python
5分钟让你掌握css3阴影、倒影、渐变小技巧(小编推荐)
2016/08/15 HTML / CSS
美国在线购买空气净化器、除湿器、加湿器网站:AllergyBuyersClub
2021/03/16 全球购物
AssertionError 跟一下那个类是 “is – a”的关系
2012/02/21 面试题
运动会方阵解说词
2014/02/12 职场文书
广告艺术设计专业自荐书
2014/07/08 职场文书
单位实习工作证明怎么写
2014/11/02 职场文书
2014年园林绿化工作总结
2014/12/11 职场文书
干部考察材料范文
2014/12/24 职场文书
稽核岗位职责范本
2015/04/13 职场文书
社区干部培训心得体会
2016/01/06 职场文书