pytorch自定义初始化权重的方法


Posted in Python onAugust 17, 2019

在常见的pytorch代码中,我们见到的初始化方式都是调用init类对每层所有参数进行初始化。但是,有时我们有些特殊需求,比如用某一层的权重取优化其它层,或者手动指定某些权重的初始值。

核心思想就是构造和该层权重同一尺寸的矩阵去对该层权重赋值。但是,值得注意的是,pytorch中各层权重的数据类型是nn.Parameter,而不是Tensor或者Variable。

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
 
# 第一一个卷积层,我们可以看到它的权值是随机初始化的
w=torch.nn.Conv2d(2,2,3,padding=1)
print(w.weight)
 
 
# 第一种方法
print("1.使用另一个Conv层的权值")
q=torch.nn.Conv2d(2,2,3,padding=1) # 假设q代表一个训练好的卷积层
print(q.weight) # 可以看到q的权重和w是不同的
w.weight=q.weight # 把一个Conv层的权重赋值给另一个Conv层
print(w.weight)
 
# 第二种方法
print("2.使用来自Tensor的权值")
ones=torch.Tensor(np.ones([2,2,3,3])) # 先创建一个自定义权值的Tensor,这里为了方便将所有权值设为1
w.weight=torch.nn.Parameter(ones) # 把Tensor的值作为权值赋值给Conv层,这里需要先转为torch.nn.Parameter类型,否则将报错
print(w.weight)

附:Variable和Parameter的区别

Parameter 是torch.autograd.Variable的一个字类,常被用于Module的参数。例如权重和偏置。

Parameters和Modules一起使用的时候会有一些特殊的属性。parameters赋值给Module的属性的时候,它会被自动加到Module的参数列表中,即会出现在Parameter()迭代器中。将Varaible赋给Module的时候没有这样的属性。这可以在nn.Module的实现中详细看一下。这样做是为了保存模型的时候只保存权重偏置参数,不保存节点值。所以复写Variable加以区分。

另外一个不同是parameter不能设置volatile,而且require_grad默认设置为true。Varaible默认设置为False.

参数:

parameter.data 得到tensor数据

parameter.requires_grad 默认为True, BP过程中会求导

Parameter一般是在Modules中作为权重和偏置,自动加入参数列表,可以进行保存恢复。和Variable具有相同的运算。

我们可以这样简单区分,在计算图中,数据(包括输入数据和计算过程中产生的feature map等)时variable类型,该类型不会被保存到模型中。 网络的权重是parameter类型,在计算过程中会被更新,将会被保存到模型中。

以上这篇pytorch自定义初始化权重的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
利用python3随机生成中文字符的实现方法
Nov 24 Python
浅谈Django REST Framework限速
Dec 12 Python
Python实现决策树C4.5算法的示例
May 30 Python
pyqt 实现在Widgets中显示图片和文字的方法
Jun 13 Python
Python3从零开始搭建一个语音对话机器人的实现
Aug 23 Python
Python爬虫爬取杭州24时温度并展示操作示例
Mar 27 Python
python中可以声明变量类型吗
Jun 18 Python
QT5 Designer 打不开的问题及解决方法
Aug 20 Python
python打包多类型文件的操作方法
Sep 21 Python
python中用ggplot绘制画图实例讲解
Jan 26 Python
python基础之匿名函数详解
Apr 21 Python
Python中的turtle画箭头,矩形,五角星
Mar 16 Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
django项目用higcharts统计最近七天文章点击量
Aug 17 #Python
Django对models里的objects的使用详解
Aug 17 #Python
You might like
Yii2 rbac权限控制之菜单menu实例教程
2016/04/28 PHP
PHP strip_tags保留多个HTML标签的方法
2016/05/22 PHP
php实现将base64格式图片保存在指定目录的方法
2016/10/13 PHP
laravel 解决groupBy时出现的错误 isn't in Group By问题
2019/10/17 PHP
PHP实现笛卡尔积算法的实例讲解
2019/12/22 PHP
jquery判断元素的子元素是否存在的示例代码
2014/02/04 Javascript
JavaScript遍历table表格中的某行某列并打印其值
2014/07/08 Javascript
javascript判断移动端访问设备并解析对应CSS的方法
2015/02/05 Javascript
js简单实现表单中点击按钮动态增加输入框数量的方法
2015/08/18 Javascript
Asp.Net之JS生成分页条的方法
2016/11/23 Javascript
Vue.js:使用Vue-Router 2实现路由功能介绍
2017/02/22 Javascript
vue中element组件样式修改无效的解决方法
2018/02/03 Javascript
小程序实现人脸识别功能(百度ai)
2018/12/23 Javascript
JS前端广告拦截实现原理解析
2020/02/17 Javascript
Node.js API详解之 V8模块用法实例分析
2020/06/05 Javascript
如何利用python查找电脑文件
2018/04/27 Python
Python抽象和自定义类定义与用法示例
2018/08/23 Python
基于numpy中数组元素的切片复制方法
2018/11/15 Python
python Tkinter版学生管理系统
2019/02/20 Python
python实现雪花飘落效果实例讲解
2019/06/18 Python
python中的decimal类型转换实例详解
2019/06/26 Python
python做接口测试的必要性
2019/11/20 Python
浅谈python已知元素,获取元素索引(numpy,pandas)
2019/11/26 Python
python 图像的离散傅立叶变换实例
2020/01/02 Python
Python操作Jira库常用方法解析
2020/04/10 Python
Python流程控制语句的深入讲解
2020/06/15 Python
python 6行代码制作月历生成器
2020/09/18 Python
HTML5 video播放器全屏(fullScreen)方法实例
2015/04/24 HTML / CSS
AmazeUI的下载配置与Helloworld的实现
2020/08/19 HTML / CSS
世界上最具创新性的增强型知名运动品牌:Proviz
2018/04/03 全球购物
社区工作感言
2014/02/21 职场文书
党的群众路线教育实践活动心得体会(乡镇)
2014/11/03 职场文书
2014年乡镇工会工作总结
2014/12/02 职场文书
员工评语范文
2014/12/31 职场文书
python urllib库的使用详解
2021/04/13 Python
Python中文分词库jieba(结巴分词)详细使用介绍
2022/04/07 Python