浅析PyTorch中nn.Module的使用


Posted in Python onAugust 18, 2019

torch.nn.Modules 相当于是对网络某种层的封装,包括网络结构以及网络参数和一些操作

torch.nn.Module 是所有神经网络单元的基类

查看源码

初始化部分:

def __init__(self):
  self._backend = thnn_backend
  self._parameters = OrderedDict()
  self._buffers = OrderedDict()
  self._backward_hooks = OrderedDict()
  self._forward_hooks = OrderedDict()
  self._forward_pre_hooks = OrderedDict()
  self._state_dict_hooks = OrderedDict()
  self._load_state_dict_pre_hooks = OrderedDict()
  self._modules = OrderedDict()
  self.training = True

属性解释:

  • _parameters:字典,保存用户直接设置的 Parameter
  • _modules:子 module,即子类构造函数中的内容
  • _buffers:缓存
  • _backward_hooks与_forward_hooks:钩子技术,用来提取中间变量
  • training:判断值来决定前向传播策略

方法定义:

def forward(self, *input):
 raise NotImplementedError

没有实际内容,用于被子类的 forward() 方法覆盖

且 forward 方法在 __call__ 方法中被调用:

def __call__(self, *input, **kwargs):
 for hook in self._forward_pre_hooks.values():
    hook(self, input)
  if torch._C._get_tracing_state():
    result = self._slow_forward(*input, **kwargs)
  else:
    result = self.forward(*input, **kwargs)
  ...
  ...

实例展示

简单搭建:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
  def __init__(self, n_feature, n_hidden, n_output):
    super(Net, self).__init__()
    self.hidden = nn.Linear(n_feature, n_hidden)
    self.out = nn.Linear(n_hidden, n_output)

  def forward(self, x):
    x = F.relu(self.hidden(x))
    x = self.out(x)
    return x

Net 类继承了 torch 的 Module 和 __init__ 功能

hidden 是隐藏层线性输出

out 是输出层线性输出

打印出网络的结构:

>>> net = Net(n_feature=10, n_hidden=30, n_output=15)
>>> print(net)
Net(
 (hidden): Linear(in_features=10, out_features=30, bias=True)
 (out): Linear(in_features=30, out_features=15, bias=True)
)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中利用Into包整洁地进行数据迁移的教程
Mar 30 Python
python使用arcpy.mapping模块批量出图
Mar 06 Python
TensorFlow实现AutoEncoder自编码器
Mar 09 Python
Python切片工具pillow用法示例
Mar 30 Python
对python 多个分隔符split 的实例详解
Dec 20 Python
在python中使用requests 模拟浏览器发送请求数据的方法
Dec 26 Python
python用requests实现http请求代码实例
Oct 31 Python
python 字段拆分详解
Dec 17 Python
Pytorch.nn.conv2d 过程验证方式(单,多通道卷积过程)
Jan 03 Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 Python
Keras 切换后端方式(Theano和TensorFlow)
Jun 19 Python
分享3个非常实用的 Python 模块
Mar 03 Python
关于PyTorch 自动求导机制详解
Aug 18 #Python
pytorch神经网络之卷积层与全连接层参数的设置方法
Aug 18 #Python
pytorch numpy list类型之间的相互转换实例
Aug 18 #Python
对Pytorch中nn.ModuleList 和 nn.Sequential详解
Aug 18 #Python
pytorch 自定义数据集加载方法
Aug 18 #Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
You might like
php使用pdo连接并查询sql数据库的方法
2014/12/24 PHP
Yii模型操作之criteria查找数据库的方法
2016/07/15 PHP
PHP7 安装event扩展的实现方法
2019/10/08 PHP
PHP基于openssl实现非对称加密代码实例
2020/06/19 PHP
老鱼 浅谈javascript面向对象编程
2010/03/04 Javascript
15 个 JavaScript Web UI 库
2010/05/19 Javascript
JavaScript开发规范要求(规范化代码)
2010/08/16 Javascript
javascript得到当前页的来路即前一页地址的方法
2014/02/18 Javascript
jQuery中对未来的元素绑定事件用bind、live or on
2014/04/17 Javascript
JavaScript实现Flash炫光波动特效
2015/05/14 Javascript
JS实现双击编辑可修改状态的方法
2015/08/14 Javascript
深入浅析JavaScript中的3DES
2016/08/24 Javascript
NodeJs读取JSON文件格式化时的注意事项
2016/09/25 NodeJs
原生js实现网易轮播图效果
2020/04/10 Javascript
nodejs基础应用
2017/02/03 NodeJs
jQuery上传多张图片带进度条样式(DEMO)
2017/03/02 Javascript
vue v-model表单控件绑定详解
2017/05/17 Javascript
nodejs实现大文件(在线视频)的读取
2020/10/16 NodeJs
jQuery实现增删改查
2020/12/22 jQuery
Python中优化NumPy包使用性能的教程
2015/04/23 Python
Python用5行代码写一个自定义简单二维码
2018/10/21 Python
python列表使用实现名字管理系统
2019/01/30 Python
python中数组和矩阵乘法及使用总结(推荐)
2019/05/18 Python
python 函数的缺省参数使用注意事项分析
2019/09/17 Python
python pip如何手动安装二进制包
2020/09/30 Python
Html5 滚动穿透的方法
2019/05/13 HTML / CSS
SHEIN香港:价格实惠的女性时尚服装
2018/08/14 全球购物
巴西服装和鞋子购物网站:Marisa
2018/10/25 全球购物
贝斯特韦斯特酒店集团官网:Best Western
2019/01/03 全球购物
大学生家政服务项目创业计划书
2014/01/30 职场文书
乡镇党员干部四风对照检查材料思想汇报
2014/09/27 职场文书
领导干部作风建设自查报告
2014/10/23 职场文书
白酒代理协议书范本
2014/10/26 职场文书
房屋买卖协议样本
2014/11/16 职场文书
3.15消费者权益日活动总结
2015/02/09 职场文书
雾霾停课通知
2015/04/24 职场文书