pytorch __init__、forward与__call__的用法小结


Posted in Python onFebruary 27, 2021

1.介绍

当我们使用pytorch来构建网络框架的时候,也会遇到和tensorflow(tensorflow __init__、build 和call小结)类似的情况,即经常会遇到__init__、forward和call这三个互相搭配着使用,那么它们的主要区别又在哪里呢?

1)__init__主要用来做参数初始化用,比如我们要初始化卷积的一些参数,就可以放到这里面,这点和tf里面的用法是一样的

2)forward是表示一个前向传播,构建网络层的先后运算步骤

3)__call__的功能其实和forward类似,所以很多时候,我们构建网络的时候,可以用__call__替代forward函数,但它们两个的区别又在哪里呢?

当网络构建完之后,调__call__的时候,会去先调forward,即__call__其实是包了一层forward,所以会导致两者的功能类似。

在pytorch在nn.Module中,实现了__call__方法,而在__call__方法中调用了forward函数:

https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py

pytorch __init__、forward与__call__的用法小结

2.代码

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class Net(nn.Module):
 def __init__(self, in_channels, mid_channels, out_channels):
 super(Net, self).__init__()
 self.conv0 = torch.nn.Sequential(
 torch.nn.Conv2d(in_channels, mid_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 torch.nn.LeakyReLU())
 self.conv1 = torch.nn.Sequential(
 torch.nn.Conv2d(mid_channels, out_channels * 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
 
 def forward(self, x):
 x = self.conv0(x)
 x = self.conv1(x)
 return x
 
class Net(nn.Module):
 def __init__(self, in_channels, mid_channels, out_channels):
 super(Net, self).__init__()
 self.conv0 = torch.nn.Sequential(
 torch.nn.Conv2d(in_channels, mid_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 torch.nn.LeakyReLU())
 self.conv1 = torch.nn.Sequential(
 torch.nn.Conv2d(mid_channels, out_channels * 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
 
 def __call__(self, x):
 x = self.conv0(x)
 x = self.conv1(x)
 return x

补充:torch/nn目录结构以及__init__.py

torch/nn目录结构以及init.py

pytorch __init__、forward与__call__的用法小结

torch/nn目录结构

__init__.py:

from .modules import *
#nn.modules  导入modules目录下内容 定义容器modules
from .parameter import Parameter
#nn.Parameter 导入parameter.py  定义parameter
from .parallel import DataParallel
#导入parallel目录下data_parallel.py中的DataParallel类
from . import init
#nn.init   导入init.py   参数初始化
from . import utils
#nn.utils  导入utils目录下内容 官网api下nn.utils下api

对于backends, functional.py, _functions 需要在代码前重新Import

例如我们常用的

import torch.nn.functional as F 就是导入了functional.py

backends和_functions是functional.py实现各种函数时所用到的。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
Python运算符重载用法实例
May 28 Python
使用rst2pdf实现将sphinx生成PDF
Jun 07 Python
python 截取 取出一部分的字符串方法
Mar 01 Python
Python实现读取txt文件并画三维图简单代码示例
Dec 09 Python
numpy中实现ndarray数组返回符合特定条件的索引方法
Apr 17 Python
利用Python如何批量修改数据库执行Sql文件
Jul 29 Python
Python3实现取图片中特定的像素替换指定的颜色示例
Jan 24 Python
关于python多重赋值的小问题
Apr 17 Python
python自动化实现登录获取图片验证码功能
Nov 20 Python
Python统计时间内的并发数代码实例
Dec 28 Python
pytorch 实现模型不同层设置不同的学习率方式
Jan 06 Python
关于win10在tensorflow的安装及在pycharm中运行步骤详解
Mar 16 Python
python 实现有道翻译功能
Feb 26 #Python
Python爬取酷狗MP3音频的步骤
Feb 26 #Python
python利用xpath爬取网上数据并存储到django模型中
Feb 26 #Python
用python 绘制茎叶图和复合饼图
Feb 26 #Python
python lambda的使用详解
Feb 26 #Python
python爬虫scrapy框架之增量式爬虫的示例代码
Feb 26 #Python
详解Python openpyxl库的基本应用
Feb 26 #Python
You might like
PHP开发文件系统实例讲解
2006/10/09 PHP
PHP中static关键字原理的学习研究分析
2011/07/18 PHP
Zend Framework教程之配置文件application.ini解析
2016/03/10 PHP
php安装ssh2扩展的方法【Linux平台】
2016/07/20 PHP
详解PHP处理字符串类似indexof的方法函数
2017/06/11 PHP
PHP实现时间比较和时间差计算的方法示例
2017/07/24 PHP
js this函数调用无需再次抓获id,name或标签名
2014/03/03 Javascript
如何让一个json文件显示在表格里【实现代码】
2016/05/09 Javascript
jQuery获取同级元素的简单代码
2016/07/09 Javascript
JavaScript之cookie技术详解
2016/11/18 Javascript
Bootstrap Img 图片样式(推荐)
2016/12/13 Javascript
原生js实现对Ajax的封装(仿jquery)
2017/01/22 Javascript
react-router实现跳转传值的方法示例
2017/05/27 Javascript
JS实现留言板功能
2017/06/17 Javascript
vue封装第三方插件并发布到npm的方法
2017/09/25 Javascript
vue计算属性和监听器实例解析
2018/05/10 Javascript
微信小程序生成海报分享朋友圈的实现方法
2019/05/06 Javascript
js实现随机抽奖
2020/03/19 Javascript
在Python中使用判断语句和循环的教程
2015/04/25 Python
python中的全局变量用法分析
2015/06/09 Python
Python实现控制台进度条功能
2016/01/04 Python
python读取二进制mnist实例详解
2017/05/31 Python
Python字符串格式化%s%d%f详解
2018/02/02 Python
Python2和Python3之间的str处理方式导致乱码的讲解
2019/01/03 Python
python 对字典按照value进行排序的方法
2019/05/09 Python
Python 实现集合Set的示例
2020/12/21 Python
施华洛世奇德国官网:SWAROVSKI德国
2017/02/01 全球购物
美国特价机票专家:Airfarewatchdog
2018/01/24 全球购物
Sneaker Studio乌克兰:购买运动鞋
2018/03/26 全球购物
什么是表空间(tablespace)和系统表空间(System tablespace)
2013/02/25 面试题
社保委托书怎么写
2014/08/02 职场文书
2014年平安夜寄语
2014/12/08 职场文书
创业方案:赚钱的烧烤店该怎样做?
2019/07/05 职场文书
使用CSS3实现按钮悬停闪烁动态特效代码
2021/08/30 HTML / CSS
python创建字典及相关管理操作
2022/04/13 Python
Apache SkyWalking 监控 MySQL Server 实战解析
2022/09/23 Servers