对Pytorch神经网络初始化kaiming分布详解


Posted in Python onAugust 18, 2019

函数的增益值

torch.nn.init.calculate_gain(nonlinearity, param=None)

提供了对非线性函数增益值的计算。

对Pytorch神经网络初始化kaiming分布详解

增益值gain是一个比例值,来调控输入数量级和输出数量级之间的关系。

fan_in和fan_out

pytorch计算fan_in和fan_out的源码


def _calculate_fan_in_and_fan_out(tensor):
 dimensions = tensor.ndimension()
 if dimensions < 2:
  raise ValueError("Fan in and fan out can not be computed 
  for tensor with fewer than 2 dimensions")

 if dimensions == 2: # Linear
  fan_in = tensor.size(1)
  fan_out = tensor.size(0)
 else:
  num_input_fmaps = tensor.size(1)
  num_output_fmaps = tensor.size(0)
  receptive_field_size = 1
  if tensor.dim() > 2:
   receptive_field_size = tensor[0][0].numel()
  fan_in = num_input_fmaps * receptive_field_size
  fan_out = num_output_fmaps * receptive_field_size

 return fan_in, fan_out

对Pytorch神经网络初始化kaiming分布详解

xavier分布

xavier分布解析:https://prateekvjoshi.com/2016/03/29/understanding-xavier-initialization-in-deep-neural-networks/

假设使用的是sigmoid函数。当权重值(值指的是绝对值)过小,输入值每经过网络层,方差都会减少,每一层的加权和很小,在sigmoid函数0附件的区域相当于线性函数,失去了DNN的非线性性。

当权重的值过大,输入值经过每一层后方差会迅速上升,每层的输出值将会很大,此时每层的梯度将会趋近于0.

xavier初始化可以使得输入值x x x<math><semantics><mrow><mi>x</mi></mrow><annotation encoding="application/x-tex">x</annotation></semantics></math>x方差经过网络层后的输出值y y y<math><semantics><mrow><mi>y</mi></mrow><annotation encoding="application/x-tex">y</annotation></semantics></math>y方差不变。

(1)xavier的均匀分布

torch.nn.init.xavier_uniform_(tensor, gain=1)

对Pytorch神经网络初始化kaiming分布详解

也称为Glorot initialization。

>>> w = torch.empty(3, 5)
>>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))

(2) xavier正态分布

torch.nn.init.xavier_normal_(tensor, gain=1)

对Pytorch神经网络初始化kaiming分布详解

也称为Glorot initialization。

kaiming分布

Xavier在tanh中表现的很好,但在Relu激活函数中表现的很差,所何凯明提出了针对于relu的初始化方法。pytorch默认使用kaiming正态分布初始化卷积层参数。

(1) kaiming均匀分布

torch.nn.init.kaiming_uniform_
 (tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

对Pytorch神经网络初始化kaiming分布详解

也被称为 He initialization。

a ? the negative slope of the rectifier used after this layer (0 for ReLU by default).激活函数的负斜率,

mode ? either ‘fan_in' (default) or ‘fan_out'. Choosing fan_in preserves the magnitude of the variance of the weights in the forward pass. Choosing fan_out preserves the magnitudes in the backwards

pass.默认为fan_in模式,fan_in可以保持前向传播的权重方差的数量级,fan_out可以保持反向传播的权重方差的数量级。

>>> w = torch.empty(3, 5)
>>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')

(2) kaiming正态分布

torch.nn.init.kaiming_normal_
 (tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

对Pytorch神经网络初始化kaiming分布详解

也被称为 He initialization。

>>> w = torch.empty(3, 5)
>>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')

以上这篇对Pytorch神经网络初始化kaiming分布详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python执行shell获取硬件参数写入mysql的方法
Dec 29 Python
列举Python中吸引人的一些特性
Apr 09 Python
简单介绍Python中的struct模块
Apr 28 Python
基于python 二维数组及画图的实例详解
Apr 03 Python
python使用Turtle库绘制动态钟表
Nov 19 Python
python实现对输入的密文加密
Mar 20 Python
如何使用Python脚本实现文件拷贝
Nov 20 Python
使用Python的Turtle绘制哆啦A梦实例
Nov 21 Python
Python如何实现邮件功能
May 27 Python
python opencv 实现读取、显示、写入图像的方法
Jun 08 Python
Python 没有main函数的原因
Jul 10 Python
Django model重写save方法及update踩坑详解
Jul 27 Python
pytorch中的embedding词向量的使用方法
Aug 18 #Python
Pytorch加载部分预训练模型的参数实例
Aug 18 #Python
在pytorch中查看可训练参数的例子
Aug 18 #Python
浅析PyTorch中nn.Module的使用
Aug 18 #Python
关于PyTorch 自动求导机制详解
Aug 18 #Python
pytorch神经网络之卷积层与全连接层参数的设置方法
Aug 18 #Python
pytorch numpy list类型之间的相互转换实例
Aug 18 #Python
You might like
咖啡豆要不要放冰箱的原因
2021/03/04 冲泡冲煮
别人整理的服务器变量:$_SERVER
2006/10/20 PHP
PHP 加密与解密的斗争
2009/04/17 PHP
火车头discuz6.1 完美采集的php接口文件
2009/09/13 PHP
6种php上传图片重命名的方法实例
2013/11/04 PHP
ThinkPHP学习笔记(一)ThinkPHP部署
2014/06/22 PHP
PHP模板引擎Smarty中的保留变量用法分析
2016/04/11 PHP
PHP使用内置函数生成图片的方法详解
2016/05/09 PHP
PHP有序表查找之二分查找(折半查找)算法示例
2018/02/09 PHP
PHP之认识(二)关于Traits的用法详解
2019/04/11 PHP
JS复制内容到剪切板的实例代码(兼容IE与火狐)
2013/11/19 Javascript
如何使用PHP+jQuery+MySQL实现异步加载ECharts地图数据(附源码下载)
2016/02/23 Javascript
深入理解js generator数据类型
2016/08/16 Javascript
jquery基本选择器匹配多个元素的实现方法
2016/09/05 Javascript
Bootstrap模态框禁用空白处点击关闭
2016/10/20 Javascript
vue自定义全局组件(自定义插件)的用法
2018/01/30 Javascript
React.js绑定this的5种方法(小结)
2018/06/05 Javascript
JavaScript简单实现的仿微博留言功能示例
2019/01/17 Javascript
mocha的时序规则讲解
2019/02/16 Javascript
Vue实现返回顶部按钮实例代码
2020/10/21 Javascript
[01:00:14]2018DOTA2亚洲邀请赛 4.6 淘汰赛 VP vs TNC 第三场
2018/04/10 DOTA
Python中 传递值 和 传递引用 的区别解析
2018/02/22 Python
利用Python如何将数据写到CSV文件中
2018/06/05 Python
python 返回列表中某个值的索引方法
2018/11/07 Python
利用pyinstaller打包exe文件的基本教程
2019/05/02 Python
简单了解Python生成器是什么
2019/07/02 Python
Python更换pip源方法过程解析
2020/05/19 Python
Keras之fit_generator与train_on_batch用法
2020/06/17 Python
HTML5实现可缩放时钟代码
2017/08/28 HTML / CSS
利达恒信公司.NET笔试题面试题
2016/03/05 面试题
AURALog面试题软件测试方面
2013/10/22 面试题
地理科学专业毕业生求职信
2013/10/15 职场文书
模具设计与制造专业求职信
2014/07/19 职场文书
Jupyter notebook 输出部分显示不全的解决方案
2021/04/24 Python
深入解析MySQL索引数据结构
2021/10/16 MySQL
动画《平凡职业成就世界最强》宣布制作OVA
2022/04/01 日漫