Pytorch加载部分预训练模型的参数实例


Posted in Python onAugust 18, 2019

前言

自从从深度学习框架caffe转到Pytorch之后,感觉Pytorch的优点妙不可言,各种设计简洁,方便研究网络结构修改,容易上手,比TensorFlow的臃肿好多了。对于深度学习的初学者,Pytorch值得推荐。今天主要主要谈谈Pytorch是如何加载预训练模型的参数以及代码的实现过程。

直接加载预选脸模型

如果我们使用的模型和预训练模型完全一样,那么我们就可以直接加载别人的模型,还有一种情况,我们在训练自己模型的过程中,突然中断了,但只要我们保存了之前的模型的参数也可以使用下面的代码直接加载我们保存的模型继续训练,不用从头开始。

model=DPN(*args, **kwargs)
model.load_state_dict(torch.load("DPN.pth"))

这样的加载方式是基于Pytorch使用的模型存储方法:

torch.save(DPN.state_dict(), "DPN.pth")

加载部分预训练模型参数

其实大多数时候我们根据自己的任物所提出的模型是在一些公开模型的基础上改变而来,其中公开模型的参数我们没有必要在从头开始训练,只要加载其训练好的模型参数即可,这样有助于提高训练的准确率和我们模型的泛化能力。

model = DPN(num_init_features=64, k_R=96, G=32, k_sec=(3,4,20,3), inc_sec=(16,32,24,128), num_classes=1,decoder=args.decoder)
 http = {'url': 'http://data.lip6.fr/cadene/pretrainedmodels/dpn92_extra-b040e4a9b.pth'}
 pretrained_dict=model_zoo.load_url(http['url'])
 model_dict = model.state_dict()
 pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}#filter out unnecessary keys 
 model_dict.update(pretrained_dict)
 model.load_state_dict(model_dict)
 model = torch.nn.DataParallel(model).cuda()

因为需要删除预训练模型中不匹配的的键,也就是层的名字。

以上这篇Pytorch加载部分预训练模型的参数实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python判断质数(素数)的简单方法讲解
May 05 Python
Python中asyncore异步模块的用法及实现httpclient的实例
Jun 28 Python
Python 正则表达式的高级用法
Dec 04 Python
Python机器学习库scikit-learn安装与基本使用教程
Jun 25 Python
pyqt5对用qt designer设计的窗体实现弹出子窗口的示例
Jun 19 Python
详解python中自定义超时异常的几种方法
Jul 29 Python
Python pandas用法最全整理
Aug 04 Python
Pandas 缺失数据处理的实现
Nov 04 Python
pytorch使用 to 进行类型转换方式
Jan 08 Python
从0到1使用python开发一个半自动答题小程序的实现
May 12 Python
python神经网络编程实现手写数字识别
May 27 Python
获取python运行输出的数据并解析存为dataFrame实例
Jul 07 Python
在pytorch中查看可训练参数的例子
Aug 18 #Python
浅析PyTorch中nn.Module的使用
Aug 18 #Python
关于PyTorch 自动求导机制详解
Aug 18 #Python
pytorch神经网络之卷积层与全连接层参数的设置方法
Aug 18 #Python
pytorch numpy list类型之间的相互转换实例
Aug 18 #Python
对Pytorch中nn.ModuleList 和 nn.Sequential详解
Aug 18 #Python
pytorch 自定义数据集加载方法
Aug 18 #Python
You might like
php下实现伪 url 的超简单方法[转]
2007/09/24 PHP
php下统计用户在线时间的一种尝试
2010/08/26 PHP
thinkphp 一个页面使用2次分页的实现方法
2013/07/15 PHP
删除html标签得到纯文本可处理嵌套的标签
2014/04/28 PHP
jQuery常见开发技巧详细整理
2013/01/02 Javascript
Nodejs学习笔记之Global Objects全局对象
2015/01/13 NodeJs
jQuery获得子元素个数的方法
2015/04/14 Javascript
javascript常见数字进制转换实例分析
2016/04/21 Javascript
全面解析JavaScript里的循环方法之forEach,for-in,for-of
2020/04/20 Javascript
客户端验证用户名和密码的方法详解
2016/06/16 Javascript
详解js产生对象的3种基本方式(工厂模式,构造函数模式,原型模式)
2017/01/09 Javascript
Vue实现点击后文字变色切换方法
2018/02/11 Javascript
详解vue-cli 快速搭建单页应用之遇到的问题及解决办法
2018/03/01 Javascript
在vue.js中使用JSZip实现在前端解压文件的方法
2018/09/05 Javascript
JS原型与继承操作示例
2019/05/09 Javascript
vue2.0 获取从http接口中获取数据,组件开发,路由配置方式
2019/11/04 Javascript
javascript实现简易数码时钟
2020/03/30 Javascript
javascript canvas时钟模拟器
2020/07/13 Javascript
Vue开发中常见的套路和技巧总结
2020/11/24 Vue.js
python list使用示例 list中找连续的数字
2014/01/27 Python
Python pass 语句使用示例
2014/03/11 Python
python3中int(整型)的使用教程
2017/03/23 Python
Python实现连接MySql数据库及增删改查操作详解
2019/04/16 Python
在PYQT5中QscrollArea(滚动条)的使用方法
2019/06/14 Python
用python求一个数组的和与平均值的实现方法
2019/06/29 Python
python实现字符串完美拆分split()的方法
2019/07/16 Python
django drf框架自带的路由及最简化的视图
2019/09/10 Python
使用IDLE的Python shell窗口实例详解
2019/11/19 Python
Python selenium的基本使用方法分析
2019/12/21 Python
如何真正的了解python装饰器
2020/08/14 Python
详解Pytorch显存动态分配规律探索
2020/11/17 Python
Canvas波浪花环的示例代码
2020/08/21 HTML / CSS
HTML+CSS+JavaScript实现图片3D展览的示例代码
2020/10/12 HTML / CSS
中专毕业个人的自荐信格式
2013/09/21 职场文书
班级课外活动总结
2014/07/09 职场文书
浅谈 JavaScript 沙箱Sandbox
2021/11/02 Javascript