pytorch在fintune时将sequential中的层输出方法,以vgg为例


Posted in Python onAugust 20, 2019

有时候我们在fintune时发现pytorch把许多层都集合在一个sequential里,但是我们希望能把中间层的结果引出来做下一步操作,于是我自己琢磨了一个方法,以vgg为例,有点僵硬哈!

首先pytorch自带的vgg16模型的网络结构如下:

VGG(
 (features): Sequential(
 (0): Conv2d (3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): ReLU(inplace)
 (2): Conv2d (64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (3): ReLU(inplace)
 (4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (5): Conv2d (64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (6): ReLU(inplace)
 (7): Conv2d (128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (8): ReLU(inplace)
 (9): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (10): Conv2d (128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (11): ReLU(inplace)
 (12): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (13): ReLU(inplace)
 (14): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (15): ReLU(inplace)
 (16): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (17): Conv2d (256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (18): ReLU(inplace)
 (19): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (20): ReLU(inplace)
 (21): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (22): ReLU(inplace)
 (23): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (24): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (25): ReLU(inplace)
 (26): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (27): ReLU(inplace)
 (28): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (29): ReLU(inplace)
 (30): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 )
 (classifier): Sequential(
 (0): Linear(in_features=25088, out_features=4096)
 (1): ReLU(inplace)
 (2): Dropout(p=0.5)
 (3): Linear(in_features=4096, out_features=4096)
 (4): ReLU(inplace)
 (5): Dropout(p=0.5)
 (6): Linear(in_features=4096, out_features=1000)
 )
)

我们需要fintune vgg16的features部分,并且我希望把3,8, 15, 22, 29这五个作为输出进一步操作。我的想法是自己写一个vgg网络,这个网络参数与pytorch的网络一致但是保证我们需要的层输出在sequential外。于是我写的网络如下:

class our_vgg(nn.Module):
 def __init__(self):
  super(our_vgg, self).__init__()
  self.conv1 = nn.Sequential(
   # conv1
   nn.Conv2d(3, 64, 3, padding=35),
   nn.ReLU(inplace=True),
   nn.Conv2d(64, 64, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv2 = nn.Sequential(
   # conv2
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/2
   nn.Conv2d(64, 128, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(128, 128, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv3 = nn.Sequential(
   # conv3
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/4
   nn.Conv2d(128, 256, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(256, 256, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(256, 256, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv4 = nn.Sequential(
   # conv4
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/8
   nn.Conv2d(256, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv5 = nn.Sequential(
   # conv5
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/16
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
  )


 def forward(self, x):

  conv1 = self.conv1(x)
  conv2 = self.conv2(conv1)
  conv3 = self.conv3(conv2)
  conv4 = self.conv4(conv3)
  conv5 = self.conv5(conv4)

  return conv5

接着就是copy weights了:

def convert_vgg(vgg16):#vgg16是pytorch自带的
 net = our_vgg()# 我写的vgg

 vgg_items = net.state_dict().items()
 vgg16_items = vgg16.items()

 pretrain_model = {}
 j = 0
 for k, v in net.state_dict().iteritems():#按顺序依次填入
  v = vgg16_items[j][1]
  k = vgg_items[j][0]
  pretrain_model[k] = v
  j += 1
 return pretrain_model


## net是我们最后使用的网络,也是我们想要放置weights的网络
net = net()

print ('load the weight from vgg')
pretrained_dict = torch.load('vgg16.pth')
pretrained_dict = convert_vgg(pretrained_dict)
model_dict = net.state_dict()
# 1. 把不属于我们需要的层剔除
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. 把参数存入已经存在的model_dict
model_dict.update(pretrained_dict) 
# 3. 加载更新后的model_dict
net.load_state_dict(model_dict)
print ('copy the weight sucessfully')

这样我就基本达成目标了,注意net也就是我们要使用的网络fintune部分需要和our_vgg一致。

以上这篇pytorch在fintune时将sequential中的层输出方法,以vgg为例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python字符串拼接、截取及替换方法总结分析
Apr 13 Python
Python+matplotlib+numpy实现在不同平面的二维条形图
Jan 02 Python
Python中常见的异常总结
Feb 20 Python
对Django外键关系的描述
Jul 26 Python
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
Aug 17 Python
python2和python3应该学哪个(python3.6与python3.7的选择)
Oct 01 Python
提升python处理速度原理及方法实例
Dec 25 Python
Python面向对象程序设计之继承、多态原理与用法详解
Mar 23 Python
python opencv 实现读取、显示、写入图像的方法
Jun 08 Python
python多线程semaphore实现线程数控制的示例
Aug 10 Python
Python列表的深复制和浅复制示例详解
Feb 12 Python
pycharm安装深度学习pytorch的d2l包失败问题解决
Mar 25 Python
python实现证件照换底功能
Aug 20 #Python
pytorch多进程加速及代码优化方法
Aug 19 #Python
用Pytorch训练CNN(数据集MNIST,使用GPU的方法)
Aug 19 #Python
python opencv实现证件照换底功能
Aug 19 #Python
解决pytorch GPU 计算过程中出现内存耗尽的问题
Aug 19 #Python
将Pytorch模型从CPU转换成GPU的实现方法
Aug 19 #Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
Aug 19 #Python
You might like
php curl基本操作详解
2013/07/23 PHP
php的sprintf函数的用法 控制浮点数格式
2014/02/14 PHP
php反射应用示例
2014/02/25 PHP
PHP删除指定目录中的所有目录及文件的方法
2015/02/26 PHP
php里array_work用法实例分析
2015/07/13 PHP
浅谈PHP各环境下的伪静态配置
2019/03/13 PHP
从面试题学习Javascript 面向对象(创建对象)
2012/03/30 Javascript
JS如何将UTC格式时间转本地格式
2013/09/04 Javascript
ajax请求get与post的区别总结
2013/11/04 Javascript
回车直接实现点击某按钮的效果即触发单击事件
2014/02/27 Javascript
nodejs分页类代码分享
2014/06/17 NodeJs
JavaScript中的console.log()函数详细介绍
2014/12/29 Javascript
javascript中offset、client、scroll的属性总结
2015/08/13 Javascript
jQuery的三种bind/One/Live/On事件绑定使用方法
2017/02/23 Javascript
JS和canvas实现俄罗斯方块
2017/03/14 Javascript
JS简单实现自定义右键菜单实例
2017/05/31 Javascript
JS操作时间 - UNIX时间戳的简单介绍(必看篇)
2017/08/16 Javascript
vue 修改 data 数据问题并实时显示的方法
2018/08/27 Javascript
为jquery的ajax请求添加超时timeout时间的操作方法
2018/09/04 jQuery
基于javascript处理二进制图片流过程详解
2020/06/08 Javascript
Vue项目中使用mock.js的完整步骤
2021/01/12 Vue.js
[03:21]【TI9纪实】Old Boys
2019/08/23 DOTA
Python+selenium实现截图图片并保存截取的图片
2018/01/05 Python
Python实现爬取百度贴吧帖子所有楼层图片的爬虫示例
2018/04/26 Python
python标识符命名规范原理解析
2020/01/10 Python
python super用法及原理详解
2020/01/20 Python
python实现读取类别频数数据画水平条形图案例
2020/04/24 Python
Python接口测试文件上传实例解析
2020/05/22 Python
理肤泉美国官网:La Roche-Posay
2018/01/17 全球购物
美国现代家具购物网站:LexMod
2019/01/09 全球购物
护理专业毕业生推荐信
2013/10/31 职场文书
社区公民道德宣传日活动总结
2015/03/23 职场文书
2015年库房工作总结
2015/04/30 职场文书
毕业实习证明范本
2015/06/16 职场文书
财务年终工作总结大全
2019/06/20 职场文书
斗罗大陆八大特殊魂兽,龙族始祖排榜首,第五最残忍(翠魔鸟)
2022/03/18 国漫