pytorch神经网络之卷积层与全连接层参数的设置方法


Posted in Python onAugust 18, 2019

当使用pytorch写网络结构的时候,本人发现在卷积层与第一个全连接层的全连接层的input_features不知道该写多少?一开始本人的做法是对着pytorch官网的公式推,但是总是算错。

后来发现,写完卷积层后可以根据模拟神经网络的前向传播得出这个。

全连接层的input_features是多少。首先来看一下这个简单的网络。这个卷积的Sequential本人就不再??铝耍?衷诳?n.Linear(???, 4096)这个全连接层的第一个参数该为多少呢?

请看下文详解。

class AlexNet(nn.Module):
  def __init__(self):
    super(AlexNet, self).__init__()

    self.conv = nn.Sequential(
      nn.Conv2d(3, 96, kernel_size=11, stride=4),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(kernel_size=3, stride=2),

      nn.Conv2d(96, 256, kernel_size=5, padding=2),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(kernel_size=3, stride=2),

      nn.Conv2d(256, 384, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.Conv2d(384, 384, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.Conv2d(384, 256, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(kernel_size=3, stride=2)
    )

    self.fc = nn.Sequential(
      nn.Linear(???, 4096)
      ......
      ......
    )

首先,我们先把forward写一下:

def forward(self, x):
    x = self.conv(x)
    print x.size()

就写到这里就可以了。其次,我们初始化一下网络,随机一个输入:

import torch
from Alexnet.AlexNet import *
from torch.autograd import Variable

if __name__ == '__main__':
  net = AlexNet()

  data_input = Variable(torch.randn([1, 3, 96, 96])) # 这里假设输入图片是96x96
  print data_input.size()
  net(data_input)

结果如下:

(1L, 3L, 96L, 96L)
(1L, 256L, 1L, 1L)

显而易见,咱们这个全连接层的input_features为256。

以上这篇pytorch神经网络之卷积层与全连接层参数的设置方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
操作Windows注册表的简单的Python程序制作教程
Apr 07 Python
在Apache服务器上同时运行多个Django程序的方法
Jul 22 Python
深入解析Python设计模式编程中建造者模式的使用
Mar 02 Python
python3实现暴力穷举博客园密码
Jun 19 Python
python如何读写csv数据
Mar 21 Python
解决python中画图时x,y轴名称出现中文乱码的问题
Jan 29 Python
pandas的连接函数concat()函数的具体使用方法
Jul 09 Python
python3 实现的对象与json相互转换操作示例
Aug 17 Python
python使用pandas抽样训练数据中某个类别实例
Feb 28 Python
Python爬取12306车次信息代码详解
Aug 12 Python
python3.8动态人脸识别的实现示例
Sep 21 Python
matplotlib交互式数据光标实现(mplcursors)
Jan 13 Python
pytorch numpy list类型之间的相互转换实例
Aug 18 #Python
对Pytorch中nn.ModuleList 和 nn.Sequential详解
Aug 18 #Python
pytorch 自定义数据集加载方法
Aug 18 #Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
You might like
php限制文件下载速度的代码
2015/10/20 PHP
微信开发之网页授权获取用户信息(二)
2016/01/08 PHP
thinkPHP5.0框架自动加载机制分析
2017/03/18 PHP
php实现的PDO异常处理操作分析
2018/12/27 PHP
jquery ajax执行后台方法
2010/03/18 Javascript
JavaScript对象之间的转换 jQuery对象和原声DOM
2011/03/07 Javascript
javascript语言结构小记(一)
2011/09/10 Javascript
js弹出的对话窗口永远保持居中显示
2012/12/15 Javascript
JavaScript模板引擎用法实例
2015/07/10 Javascript
超详细的javascript数组方法汇总
2015/11/21 Javascript
AngularJS动态生成div的ID源码解析
2016/08/29 Javascript
vue模板语法-插值详解
2017/03/06 Javascript
jquery Form轻松实现文件上传
2017/05/24 jQuery
react实现一个优雅的图片占位模块组件详解
2017/10/30 Javascript
Three.js开发实现3D地图的实践过程总结
2017/11/20 Javascript
p5.js入门教程之小球动画示例代码
2018/03/15 Javascript
vue项目中运用webpack动态配置打包多种环境域名的方法
2019/06/24 Javascript
[01:16:13]DOTA2-DPC中国联赛 正赛 SAG vs Dragon BO3 第一场 2月22日
2021/03/11 DOTA
简单分析python的类变量、实例变量
2019/08/23 Python
对Python 中矩阵或者数组相减的法则详解
2019/08/26 Python
pytorch模型存储的2种实现方法
2020/02/14 Python
Python中import导入不同目录的模块方法详解
2020/02/18 Python
Python Numpy 控制台完全输出ndarray的实现
2020/02/19 Python
python自动化测试三部曲之unittest框架的实现
2020/10/07 Python
Python爬虫后获取重定向url的两种方法
2021/01/19 Python
CSS3+HTML5+JS 实现一个块的收缩与展开动画效果
2020/11/17 HTML / CSS
浅谈Html5多线程开发之WebWorkers
2018/05/02 HTML / CSS
Zavvi美国:英国娱乐之家
2017/03/19 全球购物
潘多拉珠宝俄罗斯官方网上商店:PANDORA俄罗斯
2020/09/22 全球购物
PHP如何设置和取得Cookie值
2015/06/30 面试题
银行办公室岗位职责
2014/03/10 职场文书
恶搞卫生巾广告词
2014/03/18 职场文书
劲霸男装广告词改编版
2014/03/21 职场文书
卫校毕业生个人自我鉴定
2014/04/28 职场文书
《宝可梦》动画制作25周年到来 官方发布特别纪念视频
2022/04/01 日漫
python如何读取和存储dict()与.json格式文件
2022/06/25 Python