Pytorch 实现权重初始化


Posted in Python onDecember 31, 2019

在TensorFlow中,权重的初始化主要是在声明张量的时候进行的。 而PyTorch则提供了另一种方法:首先应该声明张量,然后修改张量的权重。通过调用torch.nn.init包中的多种方法可以将权重初始化为直接访问张量的属性。

1、不初始化的效果

在Pytorch中,定义一个tensor,不进行初始化,打印看看结果:

w = torch.Tensor(3,4)
print (w)

可以看到这时候的初始化的数值都是随机的,而且特别大,这对网络的训练必定不好,最后导致精度提不上,甚至损失无法收敛。

2、初始化的效果

PyTorch提供了多种参数初始化函数:

torch.nn.init.constant(tensor, val)
torch.nn.init.normal(tensor, mean=0, std=1)
torch.nn.init.xavier_uniform(tensor, gain=1)

等等。详细请参考:http://pytorch.org/docs/nn.html#torch-nn-init

注意上面的初始化函数的参数tensor,虽然写的是tensor,但是也可以是Variable类型的。而神经网络的参数类型Parameter是Variable类的子类,所以初始化函数可以直接作用于神经网络参数。实际上,我们初始化也是直接去初始化神经网络的参数。

让我们试试效果:

w = torch.Tensor(3,4)
torch.nn.init.normal_(w)
print (w)

3、初始化神经网络的参数

对神经网络的初始化往往放在模型的__init__()函数中,如下所示:

class Net(nn.Module):

def __init__(self, block, layers, num_classes=1000):
  self.inplanes = 64
  super(Net, self).__init__()
  ***
  *** #定义自己的网络层
  ***

  for m in self.modules():
    if isinstance(m, nn.Conv2d):
      n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
      m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
      m.weight.data.fill_(1)
      m.bias.data.zero_()

***
*** #定义后续的函数
***

也可以采取另一种方式:

定义一个权重初始化函数,如下:

def weights_init(m):
  classname = m.__class__.__name__
  if classname.find('Conv2d') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)
  elif classname.find('Linear') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)

在模型声明时,调用初始化函数,初始化神经网络参数:

model = Net(*****)
model.apply(weights_init)

以上这篇Pytorch 实现权重初始化就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python删除Java源文件中全部注释的实现方法
Aug 30 Python
Python实现变量数值交换及判断数组是否含有某个元素的方法
Sep 18 Python
python 读取Linux服务器上的文件方法
Dec 27 Python
Python空间数据处理之GDAL读写遥感图像
Aug 01 Python
Python实现图片批量加入水印代码实例
Nov 30 Python
Python timer定时器两种常用方法解析
Jan 20 Python
Python使用pyyaml模块处理yaml数据
Apr 14 Python
python实现自动清理重复文件
Aug 24 Python
详解python命令提示符窗口下如何运行python脚本
Sep 11 Python
python之pygame模块实现飞机大战完整代码
Nov 29 Python
Python Process创建进程的2种方法详解
Jan 25 Python
python入门学习关于for else的特殊特性讲解
Nov 20 Python
pytorch 归一化与反归一化实例
Dec 31 #Python
Pytorch 数据加载与数据预处理方式
Dec 31 #Python
pytorch 数据处理:定义自己的数据集合实例
Dec 31 #Python
pytorch: Parameter 的数据结构实例
Dec 31 #Python
Python测试线程应用程序过程解析
Dec 31 #Python
Python TCPServer 多线程多客户端通信的实现
Dec 31 #Python
python给指定csv表格中的联系人群发邮件(带附件的邮件)
Dec 31 #Python
You might like
php网页后退不再出现过期
2007/03/08 PHP
escape unescape的php下的实现方法
2007/04/27 PHP
php empty() 检查一个变量是否为空
2011/11/10 PHP
php采集文章中的图片获取替换到本地(实现代码)
2013/07/08 PHP
浅析php中三个等号(===)和两个等号(==)的区别
2013/08/06 PHP
php使用strtotime和date函数判断日期是否有效代码分享
2013/12/25 PHP
CI框架学习笔记(二) -入口文件index.php
2014/10/27 PHP
php实现的简易扫雷游戏实例
2015/07/09 PHP
php生成图片验证码-附五种验证码
2015/08/19 PHP
php的laravel框架快速集成微信登录的方法
2016/12/12 PHP
jquery创建div 实现代码
2009/04/27 Javascript
基于jquery中children()与find()的区别介绍
2013/04/26 Javascript
parentElement,srcElement的使用小结
2014/01/13 Javascript
vue.js初学入门教程(1)
2016/11/03 Javascript
jQuery实现的无缝广告图片左右滚动功能详解
2016/12/24 Javascript
Javascript中字符串和数字的操作方法整理
2017/01/22 Javascript
jQuery插件autocomplete使用详解
2017/02/04 Javascript
Swiper自定义分页器使用详解
2017/12/28 Javascript
vue项目添加多页面配置的步骤详解
2019/05/22 Javascript
小程序实现横向滑动日历效果
2019/10/21 Javascript
vue实现随机验证码功能(完整代码)
2019/12/10 Javascript
javascript局部自定义鼠标右键菜单
2020/12/08 Javascript
Vue——前端生成二维码的示例
2020/12/19 Vue.js
[02:49]DOTA2完美大师赛首日观众采访
2017/11/23 DOTA
python数据结构之二叉树的遍历实例
2014/04/29 Python
TensorFlow实现MLP多层感知机模型
2018/03/09 Python
Java与Python两大幸存者谁更胜一筹呢
2018/04/12 Python
Python matplotlib修改默认字体的操作
2020/03/05 Python
吃空饷专项治理工作实施方案
2014/03/04 职场文书
就业协议书怎么填
2014/04/11 职场文书
政府法律服务方案
2014/06/14 职场文书
村长党的群众路线教育实践活动个人对照检查材料
2014/09/23 职场文书
经理岗位职责范本
2015/04/15 职场文书
员工离职证明范本
2015/06/12 职场文书
详解MySQL数据库千万级数据查询和存储
2021/05/18 MySQL
gojs实现蚂蚁线动画效果
2022/02/18 Javascript