Pytorch - TORCH.NN.INIT 参数初始化的操作


Posted in Python onFebruary 27, 2021

路径:

https://pytorch.org/docs/master/nn.init.html#nn-init-doc

初始化函数:torch.nn.init

# -*- coding: utf-8 -*-
"""
Created on 2019
@author: fancp
"""
import torch 
import torch.nn as nn
w = torch.empty(3,5)
#1.均匀分布 - u(a,b)
#torch.nn.init.uniform_(tensor, a=0.0, b=1.0)
print(nn.init.uniform_(w))
# =============================================================================
# tensor([[0.9160, 0.1832, 0.5278, 0.5480, 0.6754],
#     [0.9509, 0.8325, 0.9149, 0.8192, 0.9950],
#     [0.4847, 0.4148, 0.8161, 0.0948, 0.3787]])
# =============================================================================
#2.正态分布 - N(mean, std)
#torch.nn.init.normal_(tensor, mean=0.0, std=1.0)
print(nn.init.normal_(w))
# =============================================================================
# tensor([[ 0.4388, 0.3083, -0.6803, -1.1476, -0.6084],
#     [ 0.5148, -0.2876, -1.2222, 0.6990, -0.1595],
#     [-2.0834, -1.6288, 0.5057, -0.5754, 0.3052]])
# =============================================================================
#3.常数 - 固定值 val
#torch.nn.init.constant_(tensor, val)
print(nn.init.constant_(w, 0.3))
# =============================================================================
# tensor([[0.3000, 0.3000, 0.3000, 0.3000, 0.3000],
#     [0.3000, 0.3000, 0.3000, 0.3000, 0.3000],
#     [0.3000, 0.3000, 0.3000, 0.3000, 0.3000]])
# =============================================================================
#4.全1分布
#torch.nn.init.ones_(tensor)
print(nn.init.ones_(w))
# =============================================================================
# tensor([[1., 1., 1., 1., 1.],
#     [1., 1., 1., 1., 1.],
#     [1., 1., 1., 1., 1.]])
# =============================================================================
#5.全0分布
#torch.nn.init.zeros_(tensor)
print(nn.init.zeros_(w))
# =============================================================================
# tensor([[0., 0., 0., 0., 0.],
#     [0., 0., 0., 0., 0.],
#     [0., 0., 0., 0., 0.]])
# =============================================================================
#6.对角线为 1,其它为 0
#torch.nn.init.eye_(tensor)
print(nn.init.eye_(w))
# =============================================================================
# tensor([[1., 0., 0., 0., 0.],
#     [0., 1., 0., 0., 0.],
#     [0., 0., 1., 0., 0.]])
# =============================================================================
#7.xavier_uniform 初始化
#torch.nn.init.xavier_uniform_(tensor, gain=1.0)
#From - Understanding the difficulty of training deep feedforward neural networks - Bengio 2010
print(nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu')))
# =============================================================================
# tensor([[-0.1270, 0.3963, 0.9531, -0.2949, 0.8294],
#     [-0.9759, -0.6335, 0.9299, -1.0988, -0.1496],
#     [-0.7224, 0.2181, -1.1219, 0.8629, -0.8825]])
# =============================================================================
#8.xavier_normal 初始化
#torch.nn.init.xavier_normal_(tensor, gain=1.0)
print(nn.init.xavier_normal_(w))
# =============================================================================
# tensor([[ 1.0463, 0.1275, -0.3752, 0.1858, 1.1008],
#     [-0.5560, 0.2837, 0.1000, -0.5835, 0.7886],
#     [-0.2417, 0.1763, -0.7495, 0.4677, -0.1185]])
# =============================================================================
#9.kaiming_uniform 初始化
#torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
#From - Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification - HeKaiming 2015
print(nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu'))
# =============================================================================
# tensor([[-0.7712, 0.9344, 0.8304, 0.2367, 0.0478],
#     [-0.6139, -0.3916, -0.0835, 0.5975, 0.1717],
#     [ 0.3197, -0.9825, -0.5380, -1.0033, -0.3701]])
# =============================================================================
#10.kaiming_normal 初始化
#torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
print(nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu'))
# =============================================================================
# tensor([[-0.0210, 0.5532, -0.8647, 0.9813, 0.0466],
#     [ 0.7713, -1.0418, 0.7264, 0.5547, 0.7403],
#     [-0.8471, -1.7371, 1.3333, 0.0395, 1.0787]])
# =============================================================================
#11.正交矩阵 - (semi)orthogonal matrix
#torch.nn.init.orthogonal_(tensor, gain=1)
#From - Exact solutions to the nonlinear dynamics of learning in deep linear neural networks - Saxe 2013
print(nn.init.orthogonal_(w))
# =============================================================================
# tensor([[-0.0346, -0.7607, -0.0428, 0.4771, 0.4366],
#     [-0.0412, -0.0836, 0.9847, 0.0703, -0.1293],
#     [-0.6639, 0.4551, 0.0731, 0.1674, 0.5646]])
# =============================================================================
#12.稀疏矩阵 - sparse matrix 
#torch.nn.init.sparse_(tensor, sparsity, std=0.01)
#From - Deep learning via Hessian-free optimization - Martens 2010
print(nn.init.sparse_(w, sparsity=0.1))
# =============================================================================
# tensor([[ 0.0000, 0.0000, -0.0077, 0.0000, -0.0046],
#     [ 0.0152, 0.0030, 0.0000, -0.0029, 0.0005],
#     [ 0.0199, 0.0132, -0.0088, 0.0060, 0.0000]])
# =============================================================================

补充:【pytorch参数初始化】 pytorch默认参数初始化以及自定义参数初始化

本文用两个问题来引入

1.pytorch自定义网络结构不进行参数初始化会怎样,参数值是随机的吗?

2.如何自定义参数初始化?

先回答第一个问题

在pytorch中,有自己默认初始化参数方式,所以在你定义好网络结构以后,不进行参数初始化也是可以的。

1.Conv2d继承自_ConvNd,在_ConvNd中,可以看到默认参数就是进行初始化的,如下图所示

Pytorch - TORCH.NN.INIT 参数初始化的操作

Pytorch - TORCH.NN.INIT 参数初始化的操作

2.torch.nn.BatchNorm2d也一样有默认初始化的方式

Pytorch - TORCH.NN.INIT 参数初始化的操作

3.torch.nn.Linear也如此

Pytorch - TORCH.NN.INIT 参数初始化的操作

现在来回答第二个问题。

pytorch中对神经网络模型中的参数进行初始化方法如下:

from torch.nn import init
#define the initial function to init the layer's parameters for the network
def weigth_init(m):
  if isinstance(m, nn.Conv2d):
    init.xavier_uniform_(m.weight.data)
    init.constant_(m.bias.data,0.1)
  elif isinstance(m, nn.BatchNorm2d):
    m.weight.data.fill_(1)
    m.bias.data.zero_()
  elif isinstance(m, nn.Linear):
    m.weight.data.normal_(0,0.01)
    m.bias.data.zero_()

首先定义了一个初始化函数,接着进行调用就ok了,不过要先把网络模型实例化:

#Define Network
  model = Net(args.input_channel,args.output_channel)
  model.apply(weigth_init)

此上就完成了对模型中训练参数的初始化。

在知乎上也有看到一个类似的版本,也相应的贴上来作为参考了:

def initNetParams(net):
  '''Init net parameters.'''
  for m in net.modules():
    if isinstance(m, nn.Conv2d):
      init.xavier_uniform(m.weight)
      if m.bias:
        init.constant(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
      init.constant(m.weight, 1)
      init.constant(m.bias, 0)
    elif isinstance(m, nn.Linear):
      init.normal(m.weight, std=1e-3)
      if m.bias:
        init.constant(m.bias, 0) 
initNetParams(net)

再说一下关于模型的保存及加载

1.保存有两种方式,第一种是保存模型的整个结构信息和参数,第二种是只保存模型的参数

#保存整个网络模型及参数
 torch.save(net, 'net.pkl') 
 
 #仅保存模型参数
 torch.save(net.state_dict(), 'net_params.pkl')

2.加载对应保存的两种网络

# 保存和加载整个模型 
torch.save(model_object, 'model.pth') 
model = torch.load('model.pth') 
 
# 仅保存和加载模型参数 
torch.save(model_object.state_dict(), 'params.pth') 
model_object.load_state_dict(torch.load('params.pth'))

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
在Linux中通过Python脚本访问mdb数据库的方法
May 06 Python
讲解Python中运算符使用时的优先级
May 14 Python
Python向日志输出中添加上下文信息
May 24 Python
将pandas.dataframe的数据写入到文件中的方法
Dec 07 Python
python实现批量视频分帧、保存视频帧
May 31 Python
python PIL和CV对 图片的读取,显示,裁剪,保存实现方法
Aug 07 Python
python 实现保存最新的三份文件,其余的都删掉
Dec 22 Python
keras中模型训练class_weight,sample_weight区别说明
May 23 Python
python图片验证码识别最新模块muggle_ocr的示例代码
Jul 03 Python
python 如何引入协程和原理分析
Nov 30 Python
python lambda的使用详解
Feb 26 Python
pandas数值排序的实现实例
Jul 25 Python
python FTP编程基础入门
Feb 27 #Python
python SOCKET编程基础入门
Feb 27 #Python
python 对xml解析的示例
Feb 27 #Python
python如何发送带有附件、正文为HTML的邮件
Feb 27 #Python
pytorch __init__、forward与__call__的用法小结
Feb 27 #Python
python 实现有道翻译功能
Feb 26 #Python
Python爬取酷狗MP3音频的步骤
Feb 26 #Python
You might like
一键删除顽固的空文件夹 软件下载
2007/01/26 PHP
PHP 多维数组的排序问题 根据二维数组中某个项排序
2011/11/09 PHP
使用PHP Socket 编程模拟Http post和get请求
2014/11/25 PHP
Yii2实现增删改查后留在当前页的方法详解
2017/01/13 PHP
PHP+iframe模拟Ajax上传文件功能示例
2019/07/02 PHP
共享自己写一个框架DreamScript
2007/01/20 Javascript
javascript中全局对象的parseInt()方法使用介绍
2013/12/19 Javascript
Javascript 遍历页面text控件详解
2014/01/06 Javascript
jquery.cookie.js使用指南
2015/01/05 Javascript
js从外部获取图片的实现方法
2016/08/05 Javascript
纯js实现手风琴效果代码
2020/04/17 Javascript
jquery移除了live()、die(),新版事件绑定on()、off()的方法
2016/10/26 Javascript
深入理解JavaScript中的for循环
2017/02/07 Javascript
详解VueJs异步动态加载块
2017/03/09 Javascript
angular.js中解决跨域问题的三种方式
2017/07/12 Javascript
关于express与koa的使用对比详解
2018/01/25 Javascript
解决vue-cli webpack打包开启Gzip 报错问题
2019/07/24 Javascript
jQuery表单选择器用法详解
2019/08/22 jQuery
浅谈MySQL中的触发器
2015/05/05 Python
最近Python有点火? 给你7个学习它的理由!
2017/06/26 Python
python 连接sqlite及简单操作
2017/06/30 Python
python和shell获取文本内容的方法
2018/06/05 Python
python中的插值 scipy-interp的实现代码
2018/07/23 Python
python实现自动网页截图并裁剪图片
2018/07/30 Python
Python 实现还原已撤回的微信消息
2019/06/18 Python
详解利用python+opencv识别图片中的圆形(霍夫变换)
2019/07/01 Python
python return逻辑判断表达式实现解析
2019/12/02 Python
matplotlib 三维图表绘制方法简介
2020/09/20 Python
爱尔兰家电数码商城:Currys PC World爱尔兰
2016/07/23 全球购物
捷克家具销售网站:SCONTO Nábytek
2020/01/02 全球购物
营销总经理岗位职责
2014/02/02 职场文书
领导班子对照检查剖析材料
2014/10/13 职场文书
2014年检验员工作总结
2014/11/19 职场文书
2015年销售内勤工作总结
2015/04/27 职场文书
地心历险记观后感
2015/06/15 职场文书
Oracle 临时表空间SQL语句的实现
2021/09/25 Oracle