pytorch自定义初始化权重的方法


Posted in Python onAugust 17, 2019

在常见的pytorch代码中,我们见到的初始化方式都是调用init类对每层所有参数进行初始化。但是,有时我们有些特殊需求,比如用某一层的权重取优化其它层,或者手动指定某些权重的初始值。

核心思想就是构造和该层权重同一尺寸的矩阵去对该层权重赋值。但是,值得注意的是,pytorch中各层权重的数据类型是nn.Parameter,而不是Tensor或者Variable。

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
 
# 第一一个卷积层,我们可以看到它的权值是随机初始化的
w=torch.nn.Conv2d(2,2,3,padding=1)
print(w.weight)
 
 
# 第一种方法
print("1.使用另一个Conv层的权值")
q=torch.nn.Conv2d(2,2,3,padding=1) # 假设q代表一个训练好的卷积层
print(q.weight) # 可以看到q的权重和w是不同的
w.weight=q.weight # 把一个Conv层的权重赋值给另一个Conv层
print(w.weight)
 
# 第二种方法
print("2.使用来自Tensor的权值")
ones=torch.Tensor(np.ones([2,2,3,3])) # 先创建一个自定义权值的Tensor,这里为了方便将所有权值设为1
w.weight=torch.nn.Parameter(ones) # 把Tensor的值作为权值赋值给Conv层,这里需要先转为torch.nn.Parameter类型,否则将报错
print(w.weight)

附:Variable和Parameter的区别

Parameter 是torch.autograd.Variable的一个字类,常被用于Module的参数。例如权重和偏置。

Parameters和Modules一起使用的时候会有一些特殊的属性。parameters赋值给Module的属性的时候,它会被自动加到Module的参数列表中,即会出现在Parameter()迭代器中。将Varaible赋给Module的时候没有这样的属性。这可以在nn.Module的实现中详细看一下。这样做是为了保存模型的时候只保存权重偏置参数,不保存节点值。所以复写Variable加以区分。

另外一个不同是parameter不能设置volatile,而且require_grad默认设置为true。Varaible默认设置为False.

参数:

parameter.data 得到tensor数据

parameter.requires_grad 默认为True, BP过程中会求导

Parameter一般是在Modules中作为权重和偏置,自动加入参数列表,可以进行保存恢复。和Variable具有相同的运算。

我们可以这样简单区分,在计算图中,数据(包括输入数据和计算过程中产生的feature map等)时variable类型,该类型不会被保存到模型中。 网络的权重是parameter类型,在计算过程中会被更新,将会被保存到模型中。

以上这篇pytorch自定义初始化权重的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
MAC中PyCharm设置python3解释器
Dec 15 Python
Python2.7+pytesser实现简单验证码的识别方法
Dec 29 Python
30秒轻松实现TensorFlow物体检测
Mar 14 Python
Python实现输入二叉树的先序和中序遍历,再输出后序遍历操作示例
Jul 27 Python
Python字符串逆序输出的实例讲解
Feb 16 Python
python实现海螺图片的方法示例
May 12 Python
如何在Python中实现goto语句的方法
May 18 Python
python图像和办公文档处理总结
May 28 Python
Python中输入和输出(打印)数据实例方法
Oct 13 Python
python opencv实现信用卡的数字识别
Jan 12 Python
PyTorch的SoftMax交叉熵损失和梯度用法
Jan 15 Python
教你如何使用Python下载B站视频的详细教程
Apr 29 Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
django项目用higcharts统计最近七天文章点击量
Aug 17 #Python
Django对models里的objects的使用详解
Aug 17 #Python
You might like
web目录下不应该存在多余的程序(安全考虑)
2012/05/09 PHP
解决PHP mysql_query执行超时(Fatal error: Maximum execution time …)
2013/07/03 PHP
php实现的zip文件内容比较类
2014/09/24 PHP
PHPUnit安装及使用示例
2014/10/29 PHP
php实现删除空目录的方法
2015/03/16 PHP
Nigma vs Alliance BO5 第三场2.14
2021/03/10 DOTA
脚本安需导入(装载)的三种模式的对比
2007/06/24 Javascript
用于判断用户注册时,密码强度的JS代码
2009/01/01 Javascript
查找页面中所有类为test的结点的方法
2014/03/28 Javascript
angular简介和其特点介绍
2015/01/29 Javascript
jQuery取消特定的click事件
2016/02/29 Javascript
Validform验证时可以为空否则按照指定格式验证
2017/10/20 Javascript
CentOS环境中MySQL修改root密码方法
2018/01/07 Javascript
JavaScript实现一个简易的计算器实例代码
2018/05/10 Javascript
JS伪继承prototype实现方法示例
2018/06/20 Javascript
微信小程序引用iconfont图标的方法
2018/10/22 Javascript
使用NestJS开发Node.js应用的方法
2018/12/03 Javascript
Nuxt页面级缓存的实现
2020/03/09 Javascript
vue 项目软键盘回车触发搜索事件
2020/09/09 Javascript
Python常见数据结构详解
2014/07/24 Python
Python3 socket同步通信简单示例
2017/06/07 Python
matplotlib绘制动画代码示例
2018/01/02 Python
Pandas把dataframe或series转换成list的方法
2020/06/14 Python
Python+logging输出到屏幕将log日志写入文件
2020/11/11 Python
Aerosoles爱柔仕官网:美国舒软女鞋品牌
2017/07/17 全球购物
GANT英国官方网上商店:甘特衬衫
2018/02/06 全球购物
Set里的元素是不能重复的,那么用什么方法来区分重复与否呢? 是用==还是equals()? 它们有何区别?用contains来区分是否有重复的对象。还是都不用
2013/07/30 面试题
小学评语大全
2014/04/22 职场文书
幼儿园教研活动总结
2014/04/30 职场文书
建设单位项目负责人任命书
2014/06/06 职场文书
2014年学生会干事工作总结
2014/11/07 职场文书
奖学金感谢信
2015/01/21 职场文书
2015年公务员试用期工作总结
2015/05/28 职场文书
幼儿园中班班级总结
2015/08/10 职场文书
仅仅使用 HTML/CSS 实现各类进度条的方式汇总
2021/11/11 HTML / CSS
教你nginx跳转配置的四种方式
2022/07/07 Servers