浅析PyTorch中nn.Module的使用


Posted in Python onAugust 18, 2019

torch.nn.Modules 相当于是对网络某种层的封装,包括网络结构以及网络参数和一些操作

torch.nn.Module 是所有神经网络单元的基类

查看源码

初始化部分:

def __init__(self):
  self._backend = thnn_backend
  self._parameters = OrderedDict()
  self._buffers = OrderedDict()
  self._backward_hooks = OrderedDict()
  self._forward_hooks = OrderedDict()
  self._forward_pre_hooks = OrderedDict()
  self._state_dict_hooks = OrderedDict()
  self._load_state_dict_pre_hooks = OrderedDict()
  self._modules = OrderedDict()
  self.training = True

属性解释:

  • _parameters:字典,保存用户直接设置的 Parameter
  • _modules:子 module,即子类构造函数中的内容
  • _buffers:缓存
  • _backward_hooks与_forward_hooks:钩子技术,用来提取中间变量
  • training:判断值来决定前向传播策略

方法定义:

def forward(self, *input):
 raise NotImplementedError

没有实际内容,用于被子类的 forward() 方法覆盖

且 forward 方法在 __call__ 方法中被调用:

def __call__(self, *input, **kwargs):
 for hook in self._forward_pre_hooks.values():
    hook(self, input)
  if torch._C._get_tracing_state():
    result = self._slow_forward(*input, **kwargs)
  else:
    result = self.forward(*input, **kwargs)
  ...
  ...

实例展示

简单搭建:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
  def __init__(self, n_feature, n_hidden, n_output):
    super(Net, self).__init__()
    self.hidden = nn.Linear(n_feature, n_hidden)
    self.out = nn.Linear(n_hidden, n_output)

  def forward(self, x):
    x = F.relu(self.hidden(x))
    x = self.out(x)
    return x

Net 类继承了 torch 的 Module 和 __init__ 功能

hidden 是隐藏层线性输出

out 是输出层线性输出

打印出网络的结构:

>>> net = Net(n_feature=10, n_hidden=30, n_output=15)
>>> print(net)
Net(
 (hidden): Linear(in_features=10, out_features=30, bias=True)
 (out): Linear(in_features=30, out_features=15, bias=True)
)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中列表(list)操作方法汇总
Aug 18 Python
Python中使用item()方法遍历字典的例子
Aug 26 Python
python自动化测试之从命令行运行测试用例with verbosity
Sep 28 Python
详解Python Socket网络编程
Jan 05 Python
python实现手机通讯录搜索功能
Feb 22 Python
解决python nohup linux 后台运行输出的问题
May 11 Python
python中利用h5py模块读取h5文件中的主键方法
Jun 05 Python
Django自定义用户登录认证示例代码
Jun 30 Python
Python实现常见的几种加密算法(MD5,SHA-1,HMAC,DES/AES,RSA和ECC)
May 09 Python
python语言time库和datetime库基本使用详解
Dec 25 Python
能让Python提速超40倍的神器Cython详解
Jun 24 Python
对象析构函数__del__在Python中何时使用
Mar 22 Python
关于PyTorch 自动求导机制详解
Aug 18 #Python
pytorch神经网络之卷积层与全连接层参数的设置方法
Aug 18 #Python
pytorch numpy list类型之间的相互转换实例
Aug 18 #Python
对Pytorch中nn.ModuleList 和 nn.Sequential详解
Aug 18 #Python
pytorch 自定义数据集加载方法
Aug 18 #Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
You might like
PHP中的自动加载操作实现方法详解
2019/08/06 PHP
PHP文件打开关闭及读写操作示例解析
2020/08/06 PHP
js中的string.format函数代码
2020/08/11 Javascript
js实现的折叠导航示例
2013/11/29 Javascript
js判断当前浏览器类型,判断IE浏览器方法
2014/06/02 Javascript
9个让JavaScript调试更简单的Console命令
2016/11/14 Javascript
JS基于面向对象实现的多个倒计时器功能示例
2017/02/28 Javascript
JS模拟实现ECMAScript5新增的数组方法
2017/03/20 Javascript
jQuery实现的动态文字变化输出效果示例【附演示与demo源码下载】
2017/03/24 jQuery
Vue中跨域及打包部署到nginx跨域设置方法
2019/08/26 Javascript
Vue.js中Line第三方登录api的实现代码
2020/06/29 Javascript
python列表与元组详解实例
2013/11/01 Python
在SAE上部署Python的Django框架的一些问题汇总
2015/05/30 Python
如何使用七牛Python SDK写一个同步脚本及使用教程
2015/08/23 Python
详解Python发送邮件实例
2016/01/10 Python
Python3读取Excel数据存入MySQL的方法
2018/05/04 Python
Python爬虫获取图片并下载保存至本地的实例
2018/06/01 Python
Python人脸识别第三方库face_recognition接口说明文档
2019/05/03 Python
python 求某条线上特定x值或y值的点坐标方法
2019/07/09 Python
Python全局锁中如何合理运用多线程(多进程)
2019/11/06 Python
基于python实现复制文件并重命名
2020/09/16 Python
python 利用jieba.analyse进行 关键词提取
2020/12/17 Python
全面总结使用CSS实现水平垂直居中效果的方法
2016/03/10 HTML / CSS
C#面试常见问题
2013/02/25 面试题
企业为何需要商业计划书
2013/12/26 职场文书
工业学校毕业生自荐信范文
2014/01/03 职场文书
物流专业求职计划书
2014/01/10 职场文书
公司同意接收函
2014/01/13 职场文书
大学生职业生涯规划书模板
2014/01/18 职场文书
乡镇消防工作实施方案
2014/03/27 职场文书
社团活动总结怎么写
2014/06/30 职场文书
python基于scrapy爬取京东笔记本电脑数据并进行简单处理和分析
2021/04/14 Python
Django实现聊天机器人
2021/05/31 Python
如何用python清洗文件中的数据
2021/06/18 Python
centos8安装MongoDB的详细过程
2021/10/24 MongoDB
MySql中的json_extract函数处理json字段详情
2022/06/05 MySQL