使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)


Posted in Python onJanuary 18, 2020

本文介绍了如何在pytorch下搭建AlexNet,使用了两种方法,一种是直接加载预训练模型,并根据自己的需要微调(将最后一层全连接层输出由1000改为10),另一种是手动搭建。

构建模型类的时候需要继承自torch.nn.Module类,要自己重写__ \_\___init__ \_\___方法和正向传递时的forward方法,这里我自己的理解是,搭建网络写在__ \_\___init__ \_\___中,每次正向传递需要计算的部分写在forward中,例如把矩阵压平之类的。

加载预训练alexnet之后,可以print出来查看模型的结构及信息:

使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)

model = models.alexnet(pretrained=True)
print(model)

分为两个部分,features及classifier,后续搭建模型时可以也写成这两部分,并且从打印出来的模型信息中也可以看出每一层的引用方式,便于修改,例如model.classifier[1]指的就是Linear(in_features=9216, out_features=4096, bias=True)这层。

下面放出完整的搭建代码:

import torch.nn as nn
from torchvision import models

class BuildAlexNet(nn.Module):
  def __init__(self, model_type, n_output):
    super(BuildAlexNet, self).__init__()
    self.model_type = model_type
    if model_type == 'pre':
      model = models.alexnet(pretrained=True)
      self.features = model.features
      fc1 = nn.Linear(9216, 4096)
      fc1.bias = model.classifier[1].bias
      fc1.weight = model.classifier[1].weight
      
      fc2 = nn.Linear(4096, 4096)
      fc2.bias = model.classifier[4].bias
      fc2.weight = model.classifier[4].weight
      
      self.classifier = nn.Sequential(
          nn.Dropout(),
          fc1,
          nn.ReLU(inplace=True),
          nn.Dropout(),
          fc2,
          nn.ReLU(inplace=True),
          nn.Linear(4096, n_output)) 
      #或者直接修改为
#      model.classifier[6]==nn.Linear(4096,n_output)
#      self.classifier = model.classifier
    if model_type == 'new':
      self.features = nn.Sequential(
          nn.Conv2d(3, 64, 11, 4, 2),
          nn.ReLU(inplace = True),
          nn.MaxPool2d(3, 2, 0),
          nn.Conv2d(64, 192, 5, 1, 2),
          nn.ReLU(inplace=True),
          nn.MaxPool2d(3, 2, 0),
          nn.Conv2d(192, 384, 3, 1, 1),
          nn.ReLU(inplace = True),
          nn.Conv2d(384, 256, 3, 1, 1),
          nn.ReLU(inplace=True),
          nn.MaxPool2d(3, 2, 0))
      self.classifier = nn.Sequential(
          nn.Dropout(),
          nn.Linear(9216, 4096),
          nn.ReLU(inplace=True),
          nn.Dropout(),
          nn.Linear(4096, 4096),
          nn.ReLU(inplace=True),
          nn.Linear(4096, n_output))
      
  def forward(self, x):
    x = self.features(x)
    x = x.view(x.size(0), -1)
    out = self.classifier(x)
    return out

微调预训练模型的思路为:直接保留原模型的features部分,重写classifier部分。在classifier部分中,我们实际需要修改的只有最后一层全连接层,之前的两个全连接层不需要修改,所以重写的时候需要把这两层的预训练权重和偏移保留下来,也可以像注释掉的两行代码里那样直接引用最后一层全连接层进行修改。

网络搭好之后可以小小的测试一下以检验维度是否正确。

import numpy as np
from torch.autograd import Variable
import torch

if __name__ == '__main__':
  model_type = 'pre'
  n_output = 10
  alexnet = BuildAlexNet(model_type, n_output)
  print(alexnet)
  
  x = np.random.rand(1,3,224,224)
  x = x.astype(np.float32)
  x_ts = torch.from_numpy(x)
  x_in = Variable(x_ts)
  y = alexnet(x_in)

这里如果不加“x = x.astype(np.float32)”的话会报一个类型错误,感觉有点奇怪。

输出y.data.numpy()可得10维输出,表明网络搭建正确。

以上这篇使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python正则表达式match和search用法实例
Mar 26 Python
Python下线程之间的共享和释放示例
May 04 Python
python实现的守护进程(Daemon)用法实例
Jun 02 Python
利用python批量检查网站的可用性
Sep 09 Python
PyChar学习教程之自定义文件与代码模板详解
Jul 17 Python
python itchat实现微信好友头像拼接图的示例代码
Aug 14 Python
python入门教程 python入门神图一张
Mar 05 Python
对python判断ip是否可达的实例详解
Jan 31 Python
Python安装selenium包详细过程
Jul 23 Python
使用python的turtle绘画滑稽脸实例
Nov 21 Python
利用PyQt5+Matplotlib 绘制静态/动态图的实现代码
Jul 13 Python
pycharm中使用request和Pytest进行接口测试的方法
Jul 31 Python
selenium 多窗口切换的实现(windows)
Jan 18 #Python
pytorch实现建立自己的数据集(以mnist为例)
Jan 18 #Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
You might like
Cannot modify header information错误解决方法
2008/10/08 PHP
PHP数据库万能引擎类adodb配置使用以及实例集锦
2014/06/12 PHP
PHP设置进度条的方法
2015/07/08 PHP
CI框架源码解读之URI.php中_fetch_uri_string()函数用法分析
2016/05/18 PHP
php 静态属性和静态方法区别详解
2017/04/09 PHP
JavaScript的类型简单说明
2010/09/03 Javascript
javascript高级学习笔记整理
2011/08/14 Javascript
jquery如何改变html标签的样式(两种实现方法)
2013/01/16 Javascript
js判断日期时间有效性的方法
2015/10/24 Javascript
JavaScript判断是否是微信浏览器
2016/06/13 Javascript
全屏滚动插件fullPage.js使用实例解析
2016/10/21 Javascript
bootstrap模态框垂直居中效果
2016/12/03 Javascript
jQuery Validate表单验证插件的基本使用方法及功能拓展
2017/01/04 Javascript
面包屑导航详解
2017/12/07 Javascript
JavaScript代码实现txt文件的上传预览功能
2018/03/27 Javascript
Vue-cropper 图片裁剪的基本原理及思路讲解
2018/04/17 Javascript
vue中子组件传递数据给父组件的讲解
2019/01/27 Javascript
如何在面试中手写出javascript节流和防抖函数
2020/10/22 Javascript
记一次python 内存泄漏问题及解决过程
2018/11/29 Python
Python3内置模块random随机方法小结
2019/07/13 Python
Django中create和save方法的不同
2019/08/13 Python
在Pytorch中使用样本权重(sample_weight)的正确方法
2019/08/17 Python
python SocketServer源码深入解读
2019/09/17 Python
TensorFlow索引与切片的实现方法
2019/11/20 Python
python 日志 logging模块详细解析
2020/03/31 Python
Windows下Sqlmap环境安装教程详解
2020/08/04 Python
CSS3之2D与3D变换的实现方法
2019/01/28 HTML / CSS
服装设计专业自荐书范文
2013/12/30 职场文书
开业庆典邀请函
2014/01/08 职场文书
手工社团活动方案
2014/02/17 职场文书
2014年五四青年节活动方案
2014/03/29 职场文书
就业协议书样本
2014/08/20 职场文书
三严三实对照检查材料范文
2014/09/23 职场文书
残联2016年全国助残日活动总结
2016/04/01 职场文书
Python道路车道线检测的实现
2021/06/27 Python
TypeScript中条件类型精读与实践记录
2021/10/05 Javascript