python PyTorch预训练示例


Posted in Python onFebruary 11, 2018

前言

最近使用PyTorch感觉妙不可言,有种当初使用Keras的快感,而且速度还不慢。各种设计直接简洁,方便研究,比tensorflow的臃肿好多了。今天让我们来谈谈PyTorch的预训练,主要是自己写代码的经验以及论坛PyTorch Forums上的一些回答的总结整理。

直接加载预训练模型

如果我们使用的模型和原模型完全一样,那么我们可以直接加载别人训练好的模型:

my_resnet = MyResNet(*args, **kwargs)
my_resnet.load_state_dict(torch.load("my_resnet.pth"))

当然这样的加载方法是基于PyTorch推荐的存储模型的方法:

torch.save(my_resnet.state_dict(), "my_resnet.pth")

还有第二种加载方法:

my_resnet = torch.load("my_resnet.pth")

加载部分预训练模型

其实大多数时候我们需要根据我们的任务调节我们的模型,所以很难保证模型和公开的模型完全一样,但是预训练模型的参数确实有助于提高训练的准确率,为了结合二者的优点,就需要我们加载部分预训练模型。

pretrained_dict = model_zoo.load_url(model_urls['resnet152'])
model_dict = model.state_dict()
# 将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新现有的model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的state_dict
model.load_state_dict(model_dict)

因为需要剔除原模型中不匹配的键,也就是层的名字,所以我们的新模型改变了的层需要和原模型对应层的名字不一样,比如:resnet最后一层的名字是fc(PyTorch中),那么我们修改过的resnet的最后一层就不能取这个名字,可以叫fc_

微改基础模型预训练

对于改动比较大的模型,我们可能需要自己实现一下再加载别人的预训练参数。但是,对于一些基本模型PyTorch中已经有了,而且我只想进行一些小的改动那么怎么办呢?难道我又去实现一遍吗?当然不是。

我们首先看看怎么进行微改模型。

微改基础模型

PyTorch中的torchvision里已经有很多常用的模型了,可以直接调用:

  1. AlexNet
  2. VGG
  3. ResNet
  4. SqueezeNet
  5. DenseNet
import torchvision.models as models

resnet18 = models.resnet18()
alexnet = models.alexnet()
squeezenet = models.squeezenet1_0()
densenet = models.densenet_161()

但是对于我们的任务而言有些层并不是直接能用,需要我们微微改一下,比如,resnet最后的全连接层是分1000类,而我们只有21类;又比如,resnet第一层卷积接收的通道是3, 我们可能输入图片的通道是4,那么可以通过以下方法修改:

resnet.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
resnet.fc = nn.Linear(2048, 21)

简单预训练

模型已经改完了,接下来我们就进行简单预训练吧。

我们先从torchvision中调用基本模型,加载预训练模型,然后,重点来了,将其中的层直接替换为我们需要的层即可:

resnet = torchvision.models.resnet152(pretrained=True)
# 原本为1000类,改为10类
resnet.fc = torch.nn.Linear(2048, 10)

其中使用了pretrained参数,会直接加载预训练模型,内部实现和前文提到的加载预训练的方法一样。因为是先加载的预训练参数,相当于模型中已经有参数了,所以替换掉最后一层即可。OK!

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python搭建简易服务器分析与实现
Dec 15 Python
python使用reportlab画图示例(含中文汉字)
Dec 03 Python
python机器学习之神经网络(一)
Dec 20 Python
tensorflow实现softma识别MNIST
Mar 12 Python
高效使用Python字典的清单
Apr 04 Python
Python爬虫基础之XPath语法与lxml库的用法详解
Sep 13 Python
Tesserocr库的正确安装方式
Oct 19 Python
Django objects的查询结果转化为json的三种方式的方法
Nov 07 Python
代码详解django中数据库设置
Jan 28 Python
为何人工智能(AI)首选Python?读完这篇文章你就知道了(推荐)
Apr 06 Python
python时间与Unix时间戳相互转换方法详解
Feb 13 Python
Python 中 sorted 如何自定义比较逻辑
Feb 02 Python
TensorFlow中权重的随机初始化的方法
Feb 11 #Python
python的staticmethod与classmethod实现实例代码
Feb 11 #Python
Python语言的变量认识及操作方法
Feb 11 #Python
利用Opencv中Houghline方法实现直线检测
Feb 11 #Python
tensorflow输出权重值和偏差的方法
Feb 10 #Python
详解tensorflow实现迁移学习实例
Feb 10 #Python
Python学习之Django的管理界面代码示例
Feb 10 #Python
You might like
php中preg_match的isU代表什么意思
2015/10/01 PHP
php、mysql查询当天,查询本周,查询本月的数据实例(字段是时间戳)
2017/02/04 PHP
JavaScript Perfection kill 测试及答案
2010/03/23 Javascript
基于JQuery制作的产品广告效果
2010/12/08 Javascript
jQuery EasyUI API 中文文档 - ValidateBox验证框
2011/10/06 Javascript
jQuery-serialize()输出序列化form表单值的方法
2012/12/26 Javascript
javascript的事件触发器介绍的实现
2014/06/05 Javascript
JavaScript基础知识学习笔记
2014/12/02 Javascript
两行代码轻松搞定JavaScript日期验证
2016/08/03 Javascript
Bootstrap如何激活导航状态
2017/03/22 Javascript
使用Math.max,Math.min获取数组中的最值实例
2017/04/25 Javascript
vue.js实现条件渲染的实例代码
2017/06/22 Javascript
基于rem的移动端响应式适配方案(详解)
2017/07/07 Javascript
nodejs发送http请求时遇到404长时间未响应的解决方法
2017/12/10 NodeJs
Vue封装的可编辑表格插件方法
2018/08/28 Javascript
vue实现移动端悬浮窗效果
2018/12/01 Javascript
JS实现带阴历的日历功能详解
2019/01/24 Javascript
Vue中的transition封装组件的实现方法
2019/08/13 Javascript
Node.js系列之连接DB的方法(3)
2019/08/30 Javascript
超详细小程序定位地图模块全系列开发教学
2020/11/24 Javascript
Python中Continue语句的用法的举例详解
2015/05/14 Python
详解Python中表达式i += x与i = i + x是否等价
2017/02/08 Python
Python实现字符串格式化的方法小结
2017/02/20 Python
Django框架视图层URL映射与反向解析实例分析
2019/07/29 Python
国际化的太阳镜及太阳镜配件零售商:Sunglass Hut
2016/07/26 全球购物
SAZAC的动物连体衣和动物睡衣:Kigurumi Shop
2020/03/14 全球购物
判断单链表中是否存在环
2012/07/16 面试题
大专应届生个人简历的自我评价
2013/10/15 职场文书
茶叶生产计划书
2014/01/10 职场文书
表扬通报怎么写
2015/01/16 职场文书
女性健康知识讲座主持词
2015/07/04 职场文书
ThinkPHP5和ThinkPHP6的区别
2021/03/31 PHP
自从在 IDEA 中用了热部署神器 JRebel 之后,开发效率提升了 10(真棒)
2021/06/26 Java/Android
关于python中模块和重载的问题
2021/11/02 Python
详解Flutter自定义应用程序内键盘的实现方法
2022/06/14 Java/Android
MySQL数据库之内置函数和自定义函数 function
2022/06/16 MySQL