Pytorch 抽取vgg各层并进行定制化处理的方法


Posted in Python onAugust 20, 2019

工作中有时候需要对vgg进行定制化处理,比如有些时候需要借助于vgg的层结构,但是需要使用的是2 channels输入,等等需求,这时候可以使用vgg的原始结构用class重写一遍,但是这样的方式比较慢,并且容易出错,下面给出一种比较简单的方式

def define_vgg(vgg,input_channels,endlayer,use_maxpool=False): 
  vgg_ad = copy.deepcopy(vgg)
  model = nn.Sequential()
  i = 0
  for layer in list(vgg_ad.features):
    if i > endlayer:
      break
    if isinstance(layer, nn.Conv2d) and i is 0:
      name = "conv_" + str(i)
      layer = nn.Conv2d(input_channels,
               layer.out_channels,
               layer.kernel_size,
               stride = layer.stride,
               padding=layer.padding)
      model.add_module(name, layer)
    if isinstance(layer, nn.Conv2d):
      name = "conv_" + str(i)
      model.add_module(name, layer)
 
    if isinstance(layer, nn.ReLU):
      name = "leakyrelu_" + str(i)
      layer = nn.LeakyReLU(inplace=True) 
      model.add_module(name, layer)
 
    if isinstance(layer, nn.MaxPool2d):
      name = "pool_" + str(i)
      if use_maxpool:
        model.add_module(name, layer)
      else:
        avgpool = nn.AvgPool2d(kernel_size=layer.kernel_size, stride=layer.stride, padding=layer.padding)
        model.add_module(name, avgpool)
    i += 1
  return model

函数输入项中的vgg 是直接使用的import torchvision.models.vgg16 传入的是vgg16 非预训练版本。end_layer 是需要提取的层数,这里使用了vgg.features 是指仅仅在vgg.features 上进行层的提取;也可以根据定制在classifier上进行提取。

下面是我的一个提取前7层的示例,可以使用pyCharm evaluate 上面函数返回的model,可以看到这个示例的情况,这里我的定制条件是输入通道为2 ,需要提取前7层,并且将ReLu更换为LeakyRelu。

Sequential(
 (conv_0): Conv2d(2, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (leakyrelu_1): LeakyReLU(negative_slope=0.01, inplace)
 (conv_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (leakyrelu_3): LeakyReLU(negative_slope=0.01, inplace)
 (pool_4): AvgPool2d(kernel_size=2, stride=2, padding=0)
 (conv_5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (leakyrelu_6): LeakyReLU(negative_slope=0.01, inplace)
 (conv_7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

以上这篇Pytorch 抽取vgg各层并进行定制化处理的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python爬取Coursera课程资源的详细过程
Nov 04 Python
Pyhthon中使用compileall模块编译源文件为pyc文件
Apr 28 Python
python通过函数属性实现全局变量的方法
May 16 Python
删除python pandas.DataFrame 的多重index实例
Jun 08 Python
Python定义二叉树及4种遍历方法实例详解
Jul 05 Python
Python函数any()和all()的用法及区别介绍
Sep 14 Python
解决python 无法加载downsample模型的问题
Oct 25 Python
对Python中TKinter模块中的Label组件实例详解
Jun 14 Python
jupyter notebook 中输出pyecharts图实例
Apr 23 Python
python实现吃苹果小游戏
Mar 21 Python
Python 机器学习工具包SKlearn的安装与使用
May 14 Python
Django实现翻页的示例代码
May 24 Python
python实现抠图给证件照换背景源码
Aug 20 #Python
python爬虫 基于requests模块发起ajax的get请求实现解析
Aug 20 #Python
pytorch 在sequential中使用view来reshape的例子
Aug 20 #Python
pytorch在fintune时将sequential中的层输出方法,以vgg为例
Aug 20 #Python
python实现证件照换底功能
Aug 20 #Python
pytorch多进程加速及代码优化方法
Aug 19 #Python
用Pytorch训练CNN(数据集MNIST,使用GPU的方法)
Aug 19 #Python
You might like
PHP实现数字补零功能的2个函数介绍
2014/05/12 PHP
PHP中in_array函数使用的问题与解决办法
2016/09/11 PHP
PHP重定向与伪静态区别
2017/02/19 PHP
thinkphp框架使用JWTtoken的方法详解
2019/10/10 PHP
jQuery学习笔记(3)--用jquery(插件)实现多选项卡功能
2013/04/08 Javascript
JavaScript字符串对象的concat方法实例(用于连接两个或多个字符串)
2014/10/16 Javascript
AngularJS入门教程之AngularJS模型
2016/04/18 Javascript
js验证框架之RealyEasy验证详解
2016/06/08 Javascript
JS简单实现浮动窗口效果示例
2016/09/07 Javascript
BootStrap 表单控件之单选按钮水平排列
2017/05/23 Javascript
为什么我们要做三份 Webpack 配置文件
2017/09/18 Javascript
JS实现table表格固定表头且表头随横向滚动而滚动
2017/10/26 Javascript
解决vue-quill-editor上传内容由于图片是base64的导致字符太长的问题
2018/08/20 Javascript
详解关于React-Router4.0跳转不置顶解决方案
2019/05/10 Javascript
jquery实现自定义树形表格的方法【自定义树形结构table】
2019/07/12 jQuery
解决vue组件中click事件失效的问题
2019/11/09 Javascript
python控制台中实现进度条功能
2015/11/10 Python
使用Python写一个贪吃蛇游戏实例代码
2017/08/21 Python
Python图形绘制操作之正弦曲线实现方法分析
2017/12/25 Python
Python实现的视频播放器功能完整示例
2018/02/01 Python
python 去除txt文本中的空格、数字、特定字母等方法
2018/07/24 Python
详解利用django中间件django.middleware.csrf.CsrfViewMiddleware防止csrf攻击
2018/10/09 Python
python进行文件对比的方法
2018/12/24 Python
Python实现滑动平均(Moving Average)的例子
2019/08/24 Python
使用OpCode绕过Python沙箱的方法详解
2019/09/03 Python
Python使用grequests(gevent+requests)并发发送请求过程解析
2019/09/25 Python
深入浅析python变量加逗号,的含义
2020/02/22 Python
Python预测2020高考分数和录取情况
2020/07/08 Python
让IE6、IE7、IE8支持CSS3的脚本
2010/07/20 HTML / CSS
纯css3制作网站后台管理面板
2014/12/30 HTML / CSS
银行贷款承诺书
2014/03/29 职场文书
企业口号大全
2014/06/12 职场文书
万能检讨书开头与结尾怎么写
2015/02/17 职场文书
技术员岗位职责范本
2015/04/11 职场文书
工作态度不好检讨书
2015/05/06 职场文书
唱歌比赛拉拉队口号
2015/12/25 职场文书