pytorch在fintune时将sequential中的层输出方法,以vgg为例


Posted in Python onAugust 20, 2019

有时候我们在fintune时发现pytorch把许多层都集合在一个sequential里,但是我们希望能把中间层的结果引出来做下一步操作,于是我自己琢磨了一个方法,以vgg为例,有点僵硬哈!

首先pytorch自带的vgg16模型的网络结构如下:

VGG(
 (features): Sequential(
 (0): Conv2d (3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): ReLU(inplace)
 (2): Conv2d (64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (3): ReLU(inplace)
 (4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (5): Conv2d (64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (6): ReLU(inplace)
 (7): Conv2d (128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (8): ReLU(inplace)
 (9): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (10): Conv2d (128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (11): ReLU(inplace)
 (12): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (13): ReLU(inplace)
 (14): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (15): ReLU(inplace)
 (16): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (17): Conv2d (256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (18): ReLU(inplace)
 (19): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (20): ReLU(inplace)
 (21): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (22): ReLU(inplace)
 (23): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 (24): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (25): ReLU(inplace)
 (26): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (27): ReLU(inplace)
 (28): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (29): ReLU(inplace)
 (30): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))
 )
 (classifier): Sequential(
 (0): Linear(in_features=25088, out_features=4096)
 (1): ReLU(inplace)
 (2): Dropout(p=0.5)
 (3): Linear(in_features=4096, out_features=4096)
 (4): ReLU(inplace)
 (5): Dropout(p=0.5)
 (6): Linear(in_features=4096, out_features=1000)
 )
)

我们需要fintune vgg16的features部分,并且我希望把3,8, 15, 22, 29这五个作为输出进一步操作。我的想法是自己写一个vgg网络,这个网络参数与pytorch的网络一致但是保证我们需要的层输出在sequential外。于是我写的网络如下:

class our_vgg(nn.Module):
 def __init__(self):
  super(our_vgg, self).__init__()
  self.conv1 = nn.Sequential(
   # conv1
   nn.Conv2d(3, 64, 3, padding=35),
   nn.ReLU(inplace=True),
   nn.Conv2d(64, 64, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv2 = nn.Sequential(
   # conv2
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/2
   nn.Conv2d(64, 128, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(128, 128, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv3 = nn.Sequential(
   # conv3
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/4
   nn.Conv2d(128, 256, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(256, 256, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(256, 256, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv4 = nn.Sequential(
   # conv4
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/8
   nn.Conv2d(256, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),

  )
  self.conv5 = nn.Sequential(
   # conv5
   nn.MaxPool2d(2, stride=2, ceil_mode=True), # 1/16
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
   nn.Conv2d(512, 512, 3, padding=1),
   nn.ReLU(inplace=True),
  )


 def forward(self, x):

  conv1 = self.conv1(x)
  conv2 = self.conv2(conv1)
  conv3 = self.conv3(conv2)
  conv4 = self.conv4(conv3)
  conv5 = self.conv5(conv4)

  return conv5

接着就是copy weights了:

def convert_vgg(vgg16):#vgg16是pytorch自带的
 net = our_vgg()# 我写的vgg

 vgg_items = net.state_dict().items()
 vgg16_items = vgg16.items()

 pretrain_model = {}
 j = 0
 for k, v in net.state_dict().iteritems():#按顺序依次填入
  v = vgg16_items[j][1]
  k = vgg_items[j][0]
  pretrain_model[k] = v
  j += 1
 return pretrain_model


## net是我们最后使用的网络,也是我们想要放置weights的网络
net = net()

print ('load the weight from vgg')
pretrained_dict = torch.load('vgg16.pth')
pretrained_dict = convert_vgg(pretrained_dict)
model_dict = net.state_dict()
# 1. 把不属于我们需要的层剔除
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. 把参数存入已经存在的model_dict
model_dict.update(pretrained_dict) 
# 3. 加载更新后的model_dict
net.load_state_dict(model_dict)
print ('copy the weight sucessfully')

这样我就基本达成目标了,注意net也就是我们要使用的网络fintune部分需要和our_vgg一致。

以上这篇pytorch在fintune时将sequential中的层输出方法,以vgg为例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用Python编写分析Python程序性能的工具的教程
Apr 01 Python
用ReactJS和Python的Flask框架编写留言板的代码示例
Dec 19 Python
python 自动化将markdown文件转成html文件的方法
Sep 23 Python
python3使用requests模块爬取页面内容的实战演练
Sep 25 Python
python虚拟环境virtualenv的使用教程
Oct 20 Python
Python使用re模块实现信息筛选的方法
Apr 29 Python
python如何发布自已pip项目的方法步骤
Oct 09 Python
Python绘制堆叠柱状图的实例
Jul 09 Python
Python Process多进程实现过程
Oct 22 Python
python 视频逐帧保存为图片的完整实例
Dec 10 Python
如何使用python切换hosts文件
Apr 29 Python
Python Switch Case三种实现方法代码实例
Jun 18 Python
python实现证件照换底功能
Aug 20 #Python
pytorch多进程加速及代码优化方法
Aug 19 #Python
用Pytorch训练CNN(数据集MNIST,使用GPU的方法)
Aug 19 #Python
python opencv实现证件照换底功能
Aug 19 #Python
解决pytorch GPU 计算过程中出现内存耗尽的问题
Aug 19 #Python
将Pytorch模型从CPU转换成GPU的实现方法
Aug 19 #Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
Aug 19 #Python
You might like
php各种编码集详解和以及在什么情况下进行使用
2011/09/11 PHP
php导出excel格式数据问题
2014/03/11 PHP
PHP+iFrame实现页面无需刷新的异步文件上传
2014/09/16 PHP
php生成图片验证码
2015/06/09 PHP
php实现的用户查询类实例
2015/06/18 PHP
php-msf源码详解
2017/12/25 PHP
Laravel 模型使用软删除-左连接查询-表起别名示例
2019/10/24 PHP
PHP超级全局变量【$GLOBALS,$_SERVER,$_REQUEST等】用法实例分析
2019/12/11 PHP
js操作二级联动实现代码
2010/07/27 Javascript
jQuery UI AutoComplete 使用说明
2011/06/20 Javascript
返回页面顶部top按钮通过锚点实现(自写)
2013/08/30 Javascript
javascript工厂方式定义对象
2014/12/26 Javascript
jQuery支持添加事件的日历特效代码分享(3种样式)
2015/08/24 Javascript
js的各种排序算法实现(总结)
2016/07/23 Javascript
用js写的一个路由(简单实例)
2016/09/24 Javascript
laydate日历控件使用方法详解
2017/11/20 Javascript
Node.js成为Web应用开发最佳选择的原因
2018/02/05 Javascript
浅谈如何使用webpack构建多页面应用
2018/05/30 Javascript
vue中v-for循环给标签属性赋值的方法
2018/10/18 Javascript
微信小程序如何使用globalData的方法
2019/06/06 Javascript
JavaScript实现与web通信的方法详解
2020/08/07 Javascript
js实现简单的轮播图效果
2020/12/13 Javascript
[18:32]DOTA2 HEROS教学视频教你分分钟做大人-谜团
2014/06/12 DOTA
Python3 Tkinter选择路径功能的实现方法
2019/06/14 Python
python读取图像矩阵文件并转换为向量实例
2020/06/18 Python
python多线程爬取西刺代理的示例代码
2021/01/30 Python
科颜氏加拿大官方网站: Kiehl’s加拿大
2016/08/16 全球购物
美国最大的袜子制造商和零售商:Renfro Socks
2017/09/03 全球购物
Nº21官方在线商店:numeroventuno.com
2019/09/26 全球购物
澳洲Chemist Direct药房中文网:澳洲大型线上直邮药房
2019/11/04 全球购物
如何开启linux的ssh服务
2015/02/14 面试题
学习十八大的心得体会
2014/09/01 职场文书
2015初一年级组工作总结
2015/07/24 职场文书
golang interface判断为空nil的实现代码
2021/04/24 Golang
vue-cropper组件实现图片切割上传
2021/05/27 Vue.js
JavaScript组合继承详解
2021/11/07 Javascript