PyTorch如何搭建一个简单的网络


Posted in Python onAugust 24, 2020

1 任务

首先说下我们要搭建的网络要完成的学习任务: 让我们的神经网络学会逻辑异或运算,异或运算也就是俗称的“相同取0,不同取1” 。再把我们的需求说的简单一点,也就是我们需要搭建这样一个神经网络,让我们在输入(1,1)时输出0,输入(1,0)时输出1(相同取0,不同取1),以此类推。

2 实现思路

因为我们的需求需要有两个输入,一个输出,所以我们需要在输入层设置两个输入节点,输出层设置一个输出节点。因为问题比较简单,所以隐含层我们只需要设置10个节点就可以达到不错的效果了,隐含层的激活函数我们采用ReLU函数,输出层我们用Sigmoid函数,让输出保持在0到1的一个范围,如果输出大于0.5,即可让输出结果为1,小于0.5,让输出结果为0.

3 实现过程

我们使用的简单的快速搭建法。

3.1 引入必要库

import torch
import torch.nn as nn
import numpy as np

用pytorch当然要引入torch包,然后为了写代码方便将torch包里的nn用nn来代替,nn这个包就是neural network的缩写,专门用来搭神经网络的一个包。引入numpy是为了创建矩阵作为输入。

3.2 创建训练集

# 构建输入集
x = np.mat('0 0;'
      '0 1;'
      '1 0;'
      '1 1')
x = torch.tensor(x).float()
y = np.mat('1;'
      '0;'
      '0;'
      '1')
y = torch.tensor(y).float()

我个人比较喜欢用np.mat这种方式构建矩阵,感觉写法比较简单,当然你也可以用其他的方法。但是构建完矩阵一定要有这一步 torch.tensor(x).float() ,必须要把你所创建的输入转换成tensor变量。

什么是tensor呢?你可以简单地理解他就是pytorch中用的一种变量,你想用pytorch这个框架就必须先把你的变量转换成tensor变量。而我们这个神经网络会要求你的输入和输出必须是float浮点型的,指的是tensor变量中的浮点型,而你用np.mat创建的输入是int型的,转换成tensor也会自动地转换成tensor的int型,所以要在后面加个.float()转换成浮点型。

这样我们就构建完成了输入和输出(分别是x矩阵和y矩阵),x是四行二列的一个矩阵,他的每一行是一个输入,一次输入两个值,这里我们把所有的输入情况都列了出来。输出y是一个四行一列的矩阵,每一行都是一个输出,对应x矩阵每一行的输入。

3.3 搭建网络

myNet = nn.Sequential( 
  nn.Linear(2,10),
  nn.ReLU(),
  nn.Linear(10,1),
  nn.Sigmoid()
  )
print(myNet)

输出结果:

PyTorch如何搭建一个简单的网络

我们使用nn包中的Sequential搭建网络,这个函数就是那个可以让我们像搭积木一样搭神经网络的一个东西。

nn.Linear(2,10)的意思搭建输入层,里面的2代表输入节点个数,10代表输出节点个数。Linear也就是英文的线性,意思也就是这层不包括任何其它的激活函数,你输入了啥他就给你输出了啥。nn.ReLU()这个就代表把一个激活函数层,把你刚才的输入扔到了ReLU函数中去。 接着又来了一个Linear,最后再扔到Sigmoid函数中去。 2,10,1就分别代表了三个层的个数,简单明了。

3.4 设置优化器

optimzer = torch.optim.SGD(myNet.parameters(),lr=0.05)
loss_func = nn.MSELoss()

对这一步的理解就是,你需要有一个优化的方法来训练你的网络,所以这步设置了我们所要采用的优化方法。

torch.optim.SGD的意思就是采用SGD(随机梯度下降)方法训练,你只需要把你网络的参数和学习率传进去就可以了,分别是 myNet.paramets 和 lr 。 loss_func 这句设置了代价函数,因为我们的这个问题比较简单,所以采用了MSE,也就是均方误差代价函数。

3.5 训练网络

for epoch in range(5000):
  out = myNet(x)
  loss = loss_func(out,y)
  optimzer.zero_grad()
  loss.backward()
  optimzer.step()

我这里设置了一个5000次的循环(可能不需要这么多次),让这个训练的动作迭代5000次。每一次的输出直接用myNet(x),把输入扔进你的网络就得到了输出out(就是这么简单粗暴!),然后用代价函数和你的标准输出y求误差。 清除梯度的那一步是为了每一次重新迭代时清除上一次所求出的梯度,你就把这一步记住就行,初学不用理解太深。 loss.backward() 当然就是让误差反向传播,接着 optimzer.step() 也就是让我们刚刚设置的优化器开始工作。

3.6 测试

print(myNet(x).data)

运行结果:

PyTorch如何搭建一个简单的网络

可以看到这个结果已经非常接近我们期待的结果了,当然你也可以换个数据测试,结果也会是相似的。这里简单解释下为什么我们的代码末尾加上了一个.data,因为我们的tensor变量其实是包含两个部分的,一部分是tensor数据,另一部分是tensor的自动求导参数,我们加上.data意思是输出取tensor中的数据,如果不加的话会输出下面这样:

PyTorch如何搭建一个简单的网络

以上就是PyTorch如何搭建一个简单的网络的详细内容,更多关于PyTorch搭建网络的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
Python实现模拟浏览器请求及会话保持操作示例
Jul 30 Python
零基础使用Python读写处理Excel表格的方法
May 02 Python
Python with用法:自动关闭文件进程
Jul 10 Python
Python 根据日志级别打印不同颜色的日志的方法示例
Aug 08 Python
Python 函数用法简单示例【定义、参数、返回值、函数嵌套】
Sep 20 Python
解决pycharm下pyuic工具使用的问题
Apr 08 Python
pandas分组聚合详解
Apr 10 Python
Python jieba结巴分词原理及用法解析
Nov 05 Python
python文件路径操作方法总结
Dec 21 Python
python 下载文件的几种方式分享
Apr 07 Python
用Python监控你的朋友都在浏览哪些网站?
May 27 Python
Python上下文管理器Content Manager
Jun 26 Python
Python pysnmp使用方法及代码实例
Aug 24 #Python
详解python tcp编程
Aug 24 #Python
Python rabbitMQ如何实现生产消费者模式
Aug 24 #Python
利用Python的folium包绘制城市道路图的实现示例
Aug 24 #Python
深入分析python 排序
Aug 24 #Python
超级实用的8个Python列表技巧
Aug 24 #Python
基于CentOS搭建Python Django环境过程解析
Aug 24 #Python
You might like
PHP下通过QRCode类库创建中间带网站LOGO的二维码
2014/07/12 PHP
php中count获取多维数组长度的方法
2014/11/03 PHP
PHPCMS手机站伪静态设置详细教程
2017/02/06 PHP
Yii2学习笔记之汉化yii设置表单的描述(属性标签attributeLabels)
2017/02/07 PHP
ThinkPHP实现转换数据库查询结果数据到对应类型的方法
2017/11/16 PHP
PHP asXML()函数讲解
2019/02/03 PHP
TP5.0框架实现无限极回复功能的方法分析
2019/05/04 PHP
PHP常用字符串函数用法实例总结
2020/06/04 PHP
深入理解JavaScript系列(39):设计模式之适配器模式详解
2015/03/04 Javascript
jQuery过滤HTML标签并高亮显示关键字的方法
2015/08/07 Javascript
Node.js的MongoDB驱动Mongoose基本使用教程
2016/03/01 Javascript
JavaScript的Ext JS框架中的GridPanel组件使用指南
2016/05/21 Javascript
Angularjs实现搜索关键字高亮显示效果
2017/01/17 Javascript
JS实现侧边栏鼠标经过弹出框+缓冲效果
2017/03/29 Javascript
jQuery中hover方法搭配css的hover选择器,实现选中元素突出显示方法
2017/05/08 jQuery
微信小程序实现滑动删除效果
2017/05/19 Javascript
详解react使用react-bootstrap当轮子造车
2017/08/15 Javascript
详解使用Typescript开发node.js项目(简单的环境配置)
2017/10/09 Javascript
JavaScript实现连连看连线算法
2019/01/05 Javascript
vue在自定义组件中使用v-model进行数据绑定的方法
2019/03/25 Javascript
JavaScript设计模式之门面模式原理与实现方法分析
2020/03/09 Javascript
[03:32]2014DOTA2西雅图邀请赛 CIS外卡赛赛前black专访
2014/07/09 DOTA
[43:57]LGD vs Mineski 2018国际邀请赛小组赛BO2 第二场 8.19
2018/08/21 DOTA
Python sys.path详细介绍
2013/10/17 Python
python变量不能以数字打头详解
2016/07/06 Python
Python使用cookielib模块操作cookie的实例教程
2016/07/12 Python
python基于物品协同过滤算法实现代码
2018/05/31 Python
python爬虫超时的处理的实例
2018/12/19 Python
python如何快速拼接字符串
2020/10/28 Python
全面介绍python中很常用的单元测试框架unitest
2020/12/14 Python
沪江旗下的海量优质课程平台:沪江网校
2017/11/07 全球购物
"火柴棍式"程序员面试题
2014/03/16 面试题
在浏览器端如何得到服务器端响应的XML数据
2012/11/24 面试题
2014年开学第一课活动方案
2014/03/06 职场文书
我们的节日中秋活动方案
2014/08/19 职场文书
个人催款函范文
2015/06/24 职场文书