Pytorch 实现权重初始化


Posted in Python onDecember 31, 2019

在TensorFlow中,权重的初始化主要是在声明张量的时候进行的。 而PyTorch则提供了另一种方法:首先应该声明张量,然后修改张量的权重。通过调用torch.nn.init包中的多种方法可以将权重初始化为直接访问张量的属性。

1、不初始化的效果

在Pytorch中,定义一个tensor,不进行初始化,打印看看结果:

w = torch.Tensor(3,4)
print (w)

可以看到这时候的初始化的数值都是随机的,而且特别大,这对网络的训练必定不好,最后导致精度提不上,甚至损失无法收敛。

2、初始化的效果

PyTorch提供了多种参数初始化函数:

torch.nn.init.constant(tensor, val)
torch.nn.init.normal(tensor, mean=0, std=1)
torch.nn.init.xavier_uniform(tensor, gain=1)

等等。详细请参考:http://pytorch.org/docs/nn.html#torch-nn-init

注意上面的初始化函数的参数tensor,虽然写的是tensor,但是也可以是Variable类型的。而神经网络的参数类型Parameter是Variable类的子类,所以初始化函数可以直接作用于神经网络参数。实际上,我们初始化也是直接去初始化神经网络的参数。

让我们试试效果:

w = torch.Tensor(3,4)
torch.nn.init.normal_(w)
print (w)

3、初始化神经网络的参数

对神经网络的初始化往往放在模型的__init__()函数中,如下所示:

class Net(nn.Module):

def __init__(self, block, layers, num_classes=1000):
  self.inplanes = 64
  super(Net, self).__init__()
  ***
  *** #定义自己的网络层
  ***

  for m in self.modules():
    if isinstance(m, nn.Conv2d):
      n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
      m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
      m.weight.data.fill_(1)
      m.bias.data.zero_()

***
*** #定义后续的函数
***

也可以采取另一种方式:

定义一个权重初始化函数,如下:

def weights_init(m):
  classname = m.__class__.__name__
  if classname.find('Conv2d') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)
  elif classname.find('Linear') != -1:
    init.xavier_normal_(m.weight.data)
    init.constant_(m.bias.data, 0.0)

在模型声明时,调用初始化函数,初始化神经网络参数:

model = Net(*****)
model.apply(weights_init)

以上这篇Pytorch 实现权重初始化就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中optionParser模块的使用方法实例教程
Aug 29 Python
python dict.get()和dict['key']的区别详解
Jun 30 Python
python3下实现搜狗AI API的代码示例
Apr 10 Python
Python图像的增强处理操作示例【基于ImageEnhance类】
Jan 03 Python
python 根据时间来生成唯一的字符串方法
Jan 14 Python
Django之无名分组和有名分组的实现
Apr 16 Python
Python Multiprocessing多进程 使用tqdm显示进度条的实现
Aug 13 Python
Python多线程thread及模块使用实例
Apr 28 Python
python 写函数在一定条件下需要调用自身时的写法说明
Jun 01 Python
Pycharm打开已有项目配置python环境的方法
Jul 03 Python
python实现画图工具
Aug 27 Python
python APScheduler执行定时任务介绍
Apr 19 Python
pytorch 归一化与反归一化实例
Dec 31 #Python
Pytorch 数据加载与数据预处理方式
Dec 31 #Python
pytorch 数据处理:定义自己的数据集合实例
Dec 31 #Python
pytorch: Parameter 的数据结构实例
Dec 31 #Python
Python测试线程应用程序过程解析
Dec 31 #Python
Python TCPServer 多线程多客户端通信的实现
Dec 31 #Python
python给指定csv表格中的联系人群发邮件(带附件的邮件)
Dec 31 #Python
You might like
PR值查询 | PageRank 查询
2006/12/20 PHP
PHP curl 抓取AJAX异步内容示例
2014/09/09 PHP
php中使用session防止用户非法登录后台的方法
2015/01/27 PHP
laravel-admin 中列表筛选方法
2019/10/03 PHP
javascript 模式设计之工厂模式学习心得
2010/04/27 Javascript
JS实现的表格行鼠标点击高亮效果代码
2015/11/27 Javascript
javascript运算符——逻辑运算符全面解析
2016/06/27 Javascript
Vue.js每天必学之计算属性computed与$watch
2016/09/05 Javascript
vue插件tab选项卡使用小结
2016/10/27 Javascript
require、backbone等重构手机图片查看器
2016/11/17 Javascript
javascript数组去重方法分析
2016/12/15 Javascript
json的结构与遍历方法实例分析
2017/04/25 Javascript
react+redux的升级版todoList的实现
2017/12/18 Javascript
js数组的基本使用总结
2021/01/18 Javascript
Python语言描述KNN算法与Kd树
2017/12/13 Python
Django 浅谈根据配置生成SQL语句的问题
2018/05/29 Python
python3监控CentOS磁盘空间脚本
2018/06/21 Python
Python音频操作工具PyAudio上手教程详解
2019/06/26 Python
Python中sys模块功能与用法实例详解
2020/02/26 Python
Python单例模式的四种创建方式实例解析
2020/03/04 Python
你需要学会的8个Python列表技巧
2020/06/24 Python
python实现从ftp上下载文件的实例方法
2020/07/19 Python
Python延迟绑定问题原理及解决方案
2020/08/04 Python
html5应用缓存_动力节点Java学院整理
2017/07/13 HTML / CSS
HTML5实现页面切换激活的PageVisibility API使用初探
2016/05/13 HTML / CSS
详解WebSocket跨域问题解决
2018/08/06 HTML / CSS
英国最专业的健身器材供应商之一:Best Gym Equipment
2017/12/22 全球购物
斯凯奇澳大利亚官网:SKECHERS澳大利亚
2018/03/31 全球购物
美国糖果店:Sugarfina
2019/02/21 全球购物
小学安全教育月活动总结
2014/07/07 职场文书
学习三严三实对照检查材料思想汇报
2014/09/22 职场文书
市场部岗位职责范本
2015/04/15 职场文书
工作感言一句话
2015/08/01 职场文书
党员观看《筑梦中国》心得体会
2016/01/18 职场文书
2016年少先队活动总结
2016/04/06 职场文书
Html5调用企业微信的实现
2021/04/16 HTML / CSS