使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)


Posted in Python onJanuary 18, 2020

本文介绍了如何在pytorch下搭建AlexNet,使用了两种方法,一种是直接加载预训练模型,并根据自己的需要微调(将最后一层全连接层输出由1000改为10),另一种是手动搭建。

构建模型类的时候需要继承自torch.nn.Module类,要自己重写__ \_\___init__ \_\___方法和正向传递时的forward方法,这里我自己的理解是,搭建网络写在__ \_\___init__ \_\___中,每次正向传递需要计算的部分写在forward中,例如把矩阵压平之类的。

加载预训练alexnet之后,可以print出来查看模型的结构及信息:

使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)

model = models.alexnet(pretrained=True)
print(model)

分为两个部分,features及classifier,后续搭建模型时可以也写成这两部分,并且从打印出来的模型信息中也可以看出每一层的引用方式,便于修改,例如model.classifier[1]指的就是Linear(in_features=9216, out_features=4096, bias=True)这层。

下面放出完整的搭建代码:

import torch.nn as nn
from torchvision import models

class BuildAlexNet(nn.Module):
  def __init__(self, model_type, n_output):
    super(BuildAlexNet, self).__init__()
    self.model_type = model_type
    if model_type == 'pre':
      model = models.alexnet(pretrained=True)
      self.features = model.features
      fc1 = nn.Linear(9216, 4096)
      fc1.bias = model.classifier[1].bias
      fc1.weight = model.classifier[1].weight
      
      fc2 = nn.Linear(4096, 4096)
      fc2.bias = model.classifier[4].bias
      fc2.weight = model.classifier[4].weight
      
      self.classifier = nn.Sequential(
          nn.Dropout(),
          fc1,
          nn.ReLU(inplace=True),
          nn.Dropout(),
          fc2,
          nn.ReLU(inplace=True),
          nn.Linear(4096, n_output)) 
      #或者直接修改为
#      model.classifier[6]==nn.Linear(4096,n_output)
#      self.classifier = model.classifier
    if model_type == 'new':
      self.features = nn.Sequential(
          nn.Conv2d(3, 64, 11, 4, 2),
          nn.ReLU(inplace = True),
          nn.MaxPool2d(3, 2, 0),
          nn.Conv2d(64, 192, 5, 1, 2),
          nn.ReLU(inplace=True),
          nn.MaxPool2d(3, 2, 0),
          nn.Conv2d(192, 384, 3, 1, 1),
          nn.ReLU(inplace = True),
          nn.Conv2d(384, 256, 3, 1, 1),
          nn.ReLU(inplace=True),
          nn.MaxPool2d(3, 2, 0))
      self.classifier = nn.Sequential(
          nn.Dropout(),
          nn.Linear(9216, 4096),
          nn.ReLU(inplace=True),
          nn.Dropout(),
          nn.Linear(4096, 4096),
          nn.ReLU(inplace=True),
          nn.Linear(4096, n_output))
      
  def forward(self, x):
    x = self.features(x)
    x = x.view(x.size(0), -1)
    out = self.classifier(x)
    return out

微调预训练模型的思路为:直接保留原模型的features部分,重写classifier部分。在classifier部分中,我们实际需要修改的只有最后一层全连接层,之前的两个全连接层不需要修改,所以重写的时候需要把这两层的预训练权重和偏移保留下来,也可以像注释掉的两行代码里那样直接引用最后一层全连接层进行修改。

网络搭好之后可以小小的测试一下以检验维度是否正确。

import numpy as np
from torch.autograd import Variable
import torch

if __name__ == '__main__':
  model_type = 'pre'
  n_output = 10
  alexnet = BuildAlexNet(model_type, n_output)
  print(alexnet)
  
  x = np.random.rand(1,3,224,224)
  x = x.astype(np.float32)
  x_ts = torch.from_numpy(x)
  x_in = Variable(x_ts)
  y = alexnet(x_in)

这里如果不加“x = x.astype(np.float32)”的话会报一个类型错误,感觉有点奇怪。

输出y.data.numpy()可得10维输出,表明网络搭建正确。

以上这篇使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之玩转字符串(2)更新篇
Sep 28 Python
Python数据结构与算法之字典树实现方法示例
Dec 13 Python
Python回文字符串及回文数字判定功能示例
Mar 20 Python
Python列表解析配合if else的方法
Jun 23 Python
Python整数对象实现原理详解
Jul 01 Python
Django RBAC权限管理设计过程详解
Aug 06 Python
Python Django模板之模板过滤器与自定义模板过滤器示例
Oct 18 Python
使用Python的networkx绘制精美网络图教程
Nov 21 Python
python爬虫开发之使用python爬虫库requests,urllib与今日头条搜索功能爬取搜索内容实例
Mar 10 Python
Python pandas 列转行操作详解(类似hive中explode方法)
May 18 Python
Java byte数组操纵方式代码实例解析
Jul 22 Python
Python第三方库安装缓慢的解决方法
Feb 06 Python
selenium 多窗口切换的实现(windows)
Jan 18 #Python
pytorch实现建立自己的数据集(以mnist为例)
Jan 18 #Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
You might like
PHP 面向对象详解
2012/09/13 PHP
PHP中的事务使用实例
2015/05/26 PHP
PHP构造二叉树算法示例
2017/06/21 PHP
PHP simplexml_import_dom()函数讲解
2019/02/03 PHP
PHP+jQuery实现即点即改功能示例
2019/02/21 PHP
PHP中ltrim()函数的用法与实例讲解
2019/03/28 PHP
表单提交验证类
2006/07/14 Javascript
Jquery多选下拉列表插件jquery multiselect功能介绍及使用
2013/05/24 Javascript
得到form下的所有的input的js代码
2013/11/07 Javascript
node.js中的fs.statSync方法使用说明
2014/12/16 Javascript
javascript中基本类型和引用类型的区别分析
2015/05/12 Javascript
深入分析下javascript中的[]()+!
2015/07/07 Javascript
JavaScript编写连连看小游戏
2015/07/07 Javascript
深入理解jQuery 事件处理
2016/06/14 Javascript
JS版微信6.0分享接口用法分析
2016/10/13 Javascript
EasyUI 中combotree 默认不能选择父节点的实现方法
2016/11/07 Javascript
微信小程序 详解页面跳转与返回并回传数据
2017/02/13 Javascript
extjs简介_动力节点Java学院整理
2017/07/17 Javascript
JavaScript正则表达式函数总结(常用)
2018/02/22 Javascript
vue组件之间通信实例总结(点赞功能)
2018/12/05 Javascript
vue项目打包之开发环境和部署环境的实现
2020/04/23 Javascript
js实现九宫格布局效果
2020/05/28 Javascript
动态实现element ui的el-table某列数据不同样式的示例
2021/01/22 Javascript
python迭代器与生成器详解
2016/03/10 Python
Python闭包与装饰器原理及实例解析
2020/04/30 Python
HTML5图片层叠的实现示例
2020/07/07 HTML / CSS
伊利莎白雅顿官网:Elizabeth Arden
2016/10/10 全球购物
英国儿童鞋和靴子:Start-Rite
2019/05/06 全球购物
高中毕业自我鉴定
2013/12/13 职场文书
《大作家的小老师》教学反思
2014/04/16 职场文书
城市创卫标语
2014/06/17 职场文书
乡镇消防安全责任书
2014/07/23 职场文书
Python爬虫之爬取哔哩哔哩热门视频排行榜
2021/04/28 Python
详细总结Python常见的安全问题
2021/05/21 Python
PostgreSQL解析URL的方法
2021/08/02 PostgreSQL
Spring Data JPA框架的核心概念和Repository接口
2022/04/28 Java/Android