pytorch神经网络之卷积层与全连接层参数的设置方法


Posted in Python onAugust 18, 2019

当使用pytorch写网络结构的时候,本人发现在卷积层与第一个全连接层的全连接层的input_features不知道该写多少?一开始本人的做法是对着pytorch官网的公式推,但是总是算错。

后来发现,写完卷积层后可以根据模拟神经网络的前向传播得出这个。

全连接层的input_features是多少。首先来看一下这个简单的网络。这个卷积的Sequential本人就不再??铝耍?衷诳?n.Linear(???, 4096)这个全连接层的第一个参数该为多少呢?

请看下文详解。

class AlexNet(nn.Module):
  def __init__(self):
    super(AlexNet, self).__init__()

    self.conv = nn.Sequential(
      nn.Conv2d(3, 96, kernel_size=11, stride=4),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(kernel_size=3, stride=2),

      nn.Conv2d(96, 256, kernel_size=5, padding=2),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(kernel_size=3, stride=2),

      nn.Conv2d(256, 384, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.Conv2d(384, 384, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.Conv2d(384, 256, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(kernel_size=3, stride=2)
    )

    self.fc = nn.Sequential(
      nn.Linear(???, 4096)
      ......
      ......
    )

首先,我们先把forward写一下:

def forward(self, x):
    x = self.conv(x)
    print x.size()

就写到这里就可以了。其次,我们初始化一下网络,随机一个输入:

import torch
from Alexnet.AlexNet import *
from torch.autograd import Variable

if __name__ == '__main__':
  net = AlexNet()

  data_input = Variable(torch.randn([1, 3, 96, 96])) # 这里假设输入图片是96x96
  print data_input.size()
  net(data_input)

结果如下:

(1L, 3L, 96L, 96L)
(1L, 256L, 1L, 1L)

显而易见,咱们这个全连接层的input_features为256。

以上这篇pytorch神经网络之卷积层与全连接层参数的设置方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 元组(Tuple)操作详解
Mar 11 Python
在Python中使用异步Socket编程性能测试
Jun 25 Python
Python中sort和sorted函数代码解析
Jan 25 Python
python 爬虫 批量获取代理ip的实例代码
May 22 Python
python使用matplotlib库生成随机漫步图
Aug 27 Python
python图形工具turtle绘制国际象棋棋盘
May 23 Python
Python通过VGG16模型实现图像风格转换操作详解
Jan 16 Python
python计算二维矩形IOU实例
Jan 18 Python
对Tensorflow中tensorboard日志的生成与显示详解
Feb 04 Python
基于Python实现下载网易音乐代码实例
Aug 10 Python
浅谈python锁与死锁问题
Aug 14 Python
浅谈Python数学建模之线性规划
Jun 23 Python
pytorch numpy list类型之间的相互转换实例
Aug 18 #Python
对Pytorch中nn.ModuleList 和 nn.Sequential详解
Aug 18 #Python
pytorch 自定义数据集加载方法
Aug 18 #Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
You might like
1 Tube Radio
2021/03/02 无线电
php面向对象全攻略 (十六) 对象的串行化
2009/09/30 PHP
ajax+php打造进度条代码[readyState各状态说明]
2010/04/12 PHP
使用PHP编写发红包程序
2015/07/22 PHP
PHP实现的迷你漂流瓶
2015/07/29 PHP
PHP 实现人民币小写转换成大写的方法及大小写转换函数
2017/11/17 PHP
PHP二维数组分页2种实现方法解析
2020/07/09 PHP
轻量级 JS ToolTip提示效果
2010/07/20 Javascript
关于jQuery参考实例2.0 用jQuery选择元素
2013/04/07 Javascript
javascript的内存管理详解
2013/08/07 Javascript
css样式标签和js语法属性区别
2013/11/06 Javascript
用js代码改变单选框选中状态的简单实例
2013/12/18 Javascript
jquery通过visible来判断标签是否显示或隐藏
2014/05/08 Javascript
原生js实现类似弹窗抖动效果
2015/04/02 Javascript
jquery操作select元素和option的实例代码
2016/02/03 Javascript
用JS实现轮播图效果(二)
2016/06/26 Javascript
Vue2.0组件间数据传递示例
2017/03/07 Javascript
Bootstrap table使用方法汇总
2017/11/17 Javascript
如何以Angular的姿势打开Font-Awesome详解
2018/04/22 Javascript
[00:37]DOTA2上海特级锦标赛 OG战队宣传片
2016/03/03 DOTA
[59:35]DOTA2上海特级锦标赛主赛事日 - 3 败者组第三轮#1COL VS Alliance第二局
2016/03/04 DOTA
Python设计模式之观察者模式实例
2014/04/26 Python
Django框架会话技术实例分析【Cookie与Session】
2019/05/24 Python
TensorFlow2.X使用图片制作简单的数据集训练模型
2020/04/08 Python
MxNet预训练模型到Pytorch模型的转换方式
2020/05/25 Python
CSS3模拟动画下拉菜单效果
2017/04/12 HTML / CSS
HTML5 placeholder属性详解
2016/06/22 HTML / CSS
美国殿堂级滑板、冲浪、滑雪服装品牌:Volcom(钻石)
2017/04/20 全球购物
bonprix匈牙利:女士、男士和儿童服装
2019/07/19 全球购物
莫斯科珠宝厂官方网站:Miuz
2020/09/19 全球购物
便利店的创业计划书
2014/01/15 职场文书
2015年医院护理部工作总结
2015/04/23 职场文书
2016七一建党节慰问信
2015/11/30 职场文书
《草虫的村落》教学反思
2016/02/20 职场文书
【DOTA2】总决赛血虐~ XTREME GAMING vs MAGMA - OGA DOTA PIT 2022 CN
2022/04/02 DOTA
安装Ruby和 Rails的详细步骤
2022/04/19 Ruby