用pytorch的nn.Module构造简单全链接层实例


Posted in Python onJanuary 14, 2020

python版本3.7,用的是虚拟环境安装的pytorch,这样随便折腾,不怕影响其他的python框架

1、先定义一个类Linear,继承nn.Module

import torch as t
from torch import nn
from torch.autograd import Variable as V
 
class Linear(nn.Module):

  '''因为Variable自动求导,所以不需要实现backward()'''
  def __init__(self, in_features, out_features):
    super().__init__()
    self.w = nn.Parameter( t.randn( in_features, out_features ) ) #权重w 注意Parameter是一个特殊的Variable
    self.b = nn.Parameter( t.randn( out_features ) )   #偏值b
  
  def forward( self, x ): #参数 x 是一个Variable对象
    x = x.mm( self.w )
    return x + self.b.expand_as( x ) #让b的形状符合 输出的x的形状

2、验证一下

layer = Linear( 4,3 )
input = V ( t.randn( 2 ,4 ) )#包装一个Variable作为输入
out = layer( input )
out

#成功运行,结果如下:

tensor([[-2.1934, 2.5590, 4.0233], [ 1.1098, -3.8182, 0.1848]], grad_fn=<AddBackward0>)

下面利用Linear构造一个多层网络

class Perceptron( nn.Module ):
  def __init__( self,in_features, hidden_features, out_features ):
    super().__init__()
    self.layer1 = Linear( in_features , hidden_features )
    self.layer2 = Linear( hidden_features, out_features )
  def forward ( self ,x ):
    x = self.layer1( x )
    x = t.sigmoid( x ) #用sigmoid()激活函数
    return self.layer2( x )

测试一下

perceptron = Perceptron ( 5,3 ,1 )
 
for name,param in perceptron.named_parameters(): 
  print( name, param.size() )

输出如预期:

layer1.w torch.Size([5, 3])
layer1.b torch.Size([3])
layer2.w torch.Size([3, 1])
layer2.b torch.Size([1])

以上这篇用pytorch的nn.Module构造简单全链接层实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现问号表达式(?)的方法
Nov 27 Python
对于Python的Django框架部署的一些建议
Apr 09 Python
Python中编写ORM框架的入门指引
Apr 29 Python
Django1.9 加载通过ImageField上传的图片方法
May 25 Python
python远程连接服务器MySQL数据库
Jul 02 Python
Python字符串对象实现原理详解
Jul 01 Python
详解Python list和numpy array的存储和读取方法
Nov 06 Python
Python监听剪切板实现方法代码实例
Nov 11 Python
call在Python中改进数列的实例讲解
Dec 09 Python
详解解决jupyter不能使用pytorch的问题
Feb 18 Python
conda安装tensorflow和conda常用命令小结
Feb 20 Python
python实现对doc、txt、xls等文档的读写操作
Apr 02 Python
pytorch三层全连接层实现手写字母识别方式
Jan 14 #Python
Python实现bilibili时间长度查询的示例代码
Jan 14 #Python
基于python监控程序是否关闭
Jan 14 #Python
关于pytorch中全连接神经网络搭建两种模式详解
Jan 14 #Python
使用Pytorch来拟合函数方式
Jan 14 #Python
pytorch 模拟关系拟合——回归实例
Jan 14 #Python
PyTorch实现AlexNet示例
Jan 14 #Python
You might like
php设计模式 State (状态模式)
2011/06/26 PHP
php PDO实现的事务回滚示例
2017/03/23 PHP
Prototype使用指南之selector.js
2007/01/10 Javascript
翻译整理的jQuery使用查询手册
2007/03/07 Javascript
为jQuery增加join方法的实现代码
2010/11/28 Javascript
定义JavaScript二维数组采用定义数组的数组来实现
2012/12/09 Javascript
使用POST方式弹出窗口的两种方法示例介绍
2014/01/29 Javascript
javascript实现平滑无缝滚动
2020/08/09 Javascript
AngularJS基础 ng-submit 指令简单示例
2016/08/03 Javascript
AngularJS 与Bootstrap实现表格分页实例代码
2016/10/14 Javascript
JS中使用正则表达式g模式和非g模式的区别
2017/04/01 Javascript
详解在 Angular 项目中添加 clean-blog 模板
2017/07/04 Javascript
VUE element-ui 写个复用Table组件的示例代码
2017/11/18 Javascript
Vue2.0 事件的广播与接收(观察者模式)
2018/03/14 Javascript
vue addRoutes实现动态权限路由菜单的示例
2018/05/15 Javascript
简单了解Javscript中兄弟ifream的方法调用
2019/06/17 Javascript
vue实现验证用户名是否可用
2021/01/20 Vue.js
[01:04:35]2018DOTA2亚洲邀请赛 4.3 突围赛 Secret vs VG 第一场
2018/04/04 DOTA
Python中的__new__与__init__魔术方法理解笔记
2014/11/08 Python
python urllib爬虫模块使用解析
2019/09/05 Python
Python selenium模拟手动操作实现无人值守刷积分功能
2020/05/13 Python
python 元组的使用方法
2020/06/09 Python
python cv2.resize函数high和width注意事项说明
2020/07/05 Python
CSS3实现时间轴效果
2016/07/11 HTML / CSS
HTML5 Canvas实现图片缩放、翻转、颜色渐变的代码示例
2016/02/28 HTML / CSS
使用canvas对多图片拼合并导出图片的方法
2018/08/28 HTML / CSS
美国眼镜网:GlassesUSA
2017/09/07 全球购物
什么是符号链接,什么是硬链接?符号链接与硬链接的区别是什么?
2014/01/19 面试题
银行职员自我鉴定
2014/04/20 职场文书
2014年党员创先争优承诺书
2014/05/29 职场文书
2014年医务科工作总结
2014/12/18 职场文书
帝企鹅日记观后感
2015/06/10 职场文书
遗失证明范文
2015/06/19 职场文书
暑期辅导班宣传单
2015/07/14 职场文书
MySQL 如何限制一张表的记录数
2021/09/14 MySQL
直播实况, OMG破敌三路五十分钟大战神技局摩托车
2022/04/01 DOTA