关于pytorch中全连接神经网络搭建两种模式详解


Posted in Python onJanuary 14, 2020

pytorch搭建神经网络是很简单明了的,这里介绍两种自己常用的搭建模式:

import torch
import torch.nn as nn

first:

class NN(nn.Module):
 def __init__(self):
  super(NN,self).__init__()
  self.model=nn.Sequential(
   nn.Linear(30,40),
   nn.ReLU(),
   nn.Linear(40,60),
   nn.Tanh(),
   nn.Linear(60,10),
   nn.Softmax()
  )
  self.model[0].weight.data.uniform_(-3e-3, 3e-3)
  self.model[0].bias.data.uniform(-1,1)
 def forward(self,states):
  return self.model(states)

这一种是将整个网络写在一个Sequential中,网络参数设置可以在网络搭建好后单独设置:self.model[0].weight.data.uniform_(-3e-3,3e-3),这是设置第一个linear的权重是(-3e-3,3e-3)之间的均匀分布,bias是-1至1之间的均匀分布。

second:

class NN1(nn.Module):
 def __init__(self):
  super(NN1,self).__init__()
  self.Linear1=nn.Linear(30,40)
  self.Linear1.weight.data.fill_(-0.1)
  #self.Linear1.weight.data.uniform_(-3e-3,3e-3)
  self.Linear1.bias.data.fill_(-0.1)
  self.layer1=nn.Sequential(self.Linear1,nn.ReLU())

  self.Linear2=nn.Linear(40,60)
  self.layer2=nn.Sequential(self.Linear2,nn.Tanh())

  self.Linear3=nn.Linear(60,10)
  self.layer3=nn.Sequential(self.Linear3,nn.Softmax())


 def forward(self,states):
  return self.model(states)

网络参数的设置可以在定义完线性层之后直接设置如这里对于第一个线性层是这样设置:self.Linear1.weight.data.fill_(-0.1),self.Linear1.bias.data.fill_(-0.1)。

你可以看一下这样定义完的参数的效果:

Net=NN()
print("0:",Net.model[0])
print("weight:",type(Net.model[0].weight))
print("weight:",type(Net.model[0].weight.data))
print("bias",Net.model[0].bias.data)
print('1:',Net.model[1])
#print("weight:",Net.model[1].weight.data)
print('2:',Net.model[2])
print('3:',Net.model[3])
#print(Net.model[-1])

Net1=NN1()
print(Net1.Linear1.weight.data)

输出:

0: Linear (30 -> 40)
weight: <class 'torch.nn.parameter.Parameter'>
weight: <class 'torch.FloatTensor'>
bias 
-0.6287
-0.6573
-0.0452
 0.9594
-0.7477
 0.1363
-0.1594
-0.1586
 0.0360
 0.7375
 0.2501
-0.1371
 0.8359
-0.9684
-0.3886
 0.7200
-0.3906
 0.4911
 0.8081
-0.5449
 0.9872
 0.2004
 0.0969
-0.9712
 0.0873
 0.4562
-0.4857
-0.6013
 0.1651
 0.3315
-0.7033
-0.7440
 0.6487
 0.9802
-0.5977
 0.3245
 0.7563
 0.5596
 0.2303
-0.3836
[torch.FloatTensor of size 40]

1: ReLU ()
2: Linear (40 -> 60)
3: Tanh ()

-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
   ...    ⋱    ...   
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
-0.1000 -0.1000 -0.1000 ... -0.1000 -0.1000 -0.1000
[torch.FloatTensor of size 40x30]


Process finished with exit code 0

这里要注意self.Linear1.weight的类型是网络的parameter。而self.Linear1.weight.data是FloatTensor。

以上这篇关于pytorch中全连接神经网络搭建两种模式详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中对列表排序实例
Jan 04 Python
利用Python将时间或时间间隔转为ISO 8601格式方法示例
Sep 05 Python
Python3匿名函数lambda介绍与使用示例
May 18 Python
Python 3.8中实现functools.cached_property功能
May 29 Python
python安装requests库的实例代码
Jun 25 Python
Python基于Opencv来快速实现人脸识别过程详解(完整版)
Jul 11 Python
关于pytorch多GPU训练实例与性能对比分析
Aug 19 Python
Python数据持久化存储实现方法分析
Dec 21 Python
Python中if有多个条件处理方法
Feb 26 Python
Python2与Python3关于字符串编码处理的差别总结
Sep 07 Python
Django ModelForm组件原理及用法详解
Oct 12 Python
详解python的内存分配机制
May 10 Python
使用Pytorch来拟合函数方式
Jan 14 #Python
pytorch 模拟关系拟合——回归实例
Jan 14 #Python
PyTorch实现AlexNet示例
Jan 14 #Python
Pytorch 实现focal_loss 多类别和二分类示例
Jan 14 #Python
Python实现钉钉订阅消息功能
Jan 14 #Python
Python Tensor FLow简单使用方法实例详解
Jan 14 #Python
Python利用全连接神经网络求解MNIST问题详解
Jan 14 #Python
You might like
PHP 5.3新特性命名空间规则解析及高级功能
2010/03/11 PHP
由php的call_user_func传reference引发的思考
2010/07/23 PHP
ThinkPHP模版中导入CSS和JS文件的方法
2014/11/29 PHP
用jQuery简化JavaScript开发分析
2009/02/19 Javascript
javascript sudoku 数独智力游戏生成代码
2010/03/27 Javascript
修改或扩展jQuery原生方法的代码实例
2015/01/13 Javascript
嵌入式iframe子页面与父页面js通信的方法
2015/01/20 Javascript
JS实现点击上移下移LI行数据的方法
2015/08/05 Javascript
bootstrap导航、选项卡实现代码
2016/12/28 Javascript
基于jQuery选择器之表单对象属性筛选选择器的实例
2017/09/19 jQuery
解决vue.js在编写过程中出现空格不规范报错的问题
2017/09/20 Javascript
laydate如何根据开始时间或者结束时间限制范围
2018/11/15 Javascript
了解JavaScript函数中的默认参数
2019/05/30 Javascript
vue实现员工信息录入功能
2020/06/11 Javascript
在js文件中引入(调用)另一个js文件的三种方法
2020/09/11 Javascript
JS获取一个字符串中指定字符串第n次出现的位置
2021/02/10 Javascript
python 中文乱码问题深入分析
2011/03/13 Python
Flask框架的学习指南之用户登录管理
2016/11/20 Python
在Python中执行系统命令的方法示例详解
2017/09/14 Python
详解Python安装scrapy的正确姿势
2018/06/26 Python
详解Python locals()的陷阱
2019/03/26 Python
Python中的 is 和 == 以及字符串驻留机制详解
2019/06/28 Python
详解matplotlib中pyplot和面向对象两种绘图模式之间的关系
2021/01/22 Python
使用before和:after伪类制作css3圆形按钮
2014/04/08 HTML / CSS
维多利亚的秘密官方旗舰店:VICTORIA’S SECRET
2018/04/02 全球购物
在校生党员自我评价
2013/09/25 职场文书
食品安全工作实施方案
2014/03/26 职场文书
委托书怎样写
2014/08/30 职场文书
“四风”问题整改措施和努力方向
2014/09/20 职场文书
债务纠纷委托书范本
2014/10/14 职场文书
协议书范文
2015/01/27 职场文书
大连导游词
2015/02/12 职场文书
2015年学校工作总结范文
2015/04/20 职场文书
办公室禁烟通知
2015/04/23 职场文书
贷款工作证明模板
2015/06/12 职场文书
mysql使用 not int 子查询隐含陷阱
2022/04/12 MySQL