PyTorch学习笔记之回归实战


Posted in Python onMay 28, 2018

本文主要是用PyTorch来实现一个简单的回归任务。

编辑器:spyder

1.引入相应的包及生成伪数据

import torch
import torch.nn.functional as F # 主要实现激活函数
import matplotlib.pyplot as plt # 绘图的工具
from torch.autograd import Variable

# 生成伪数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim = 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())

# 变为Variable
x, y = Variable(x), Variable(y)

其中torch.linspace是为了生成连续间断的数据,第一个参数表示起点,第二个参数表示终点,第三个参数表示将这个区间分成平均几份,即生成几个数据。因为torch只能处理二维的数据,所以我们用torch.unsqueeze给伪数据添加一个维度,dim表示添加在第几维。torch.rand返回的是[0,1)之间的均匀分布。

2.绘制数据图像

在上述代码后面加下面的代码,然后运行可得伪数据的图形化表示:

# 绘制数据图像
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()

PyTorch学习笔记之回归实战

3.建立神经网络

class Net(torch.nn.Module):
 def __init__(self, n_feature, n_hidden, n_output):
  super(Net, self).__init__()
  self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
  self.predict = torch.nn.Linear(n_hidden, n_output) # output layer

 def forward(self, x):
  x = F.relu(self.hidden(x))  # activation function for hidden layer
  x = self.predict(x)    # linear output
  return x

net = Net(n_feature=1, n_hidden=10, n_output=1)  # define the network
print(net) # net architecture

一般神经网络的类都继承自torch.nn.Module__init__()和forward()两个函数是自定义类的主要函数。在__init__()中都要添加一句super(Net, self).__init__(),这是固定的标准写法,用于继承父类的初始化函数。__init__()中只是对神经网络的模块进行了声明,真正的搭建是在forwad()中实现。自定义类中的成员都通过self指针来进行访问,所以参数列表中都包含了self。

如果想查看网络结构,可以用print()函数直接打印网络。本文的网络结构输出如下:

Net (
 (hidden): Linear (1 -> 10)
 (predict): Linear (10 -> 1)
)

4.训练网络

# 训练100次
for t in range(100):
 prediction = net(x)  # input x and predict based on x

 loss = loss_func(prediction, y)  # 一定要是输出在前,标签在后 (1. nn output, 2. target)

 optimizer.zero_grad() # clear gradients for next train
 loss.backward()   # backpropagation, compute gradients
 optimizer.step()  # apply gradients

训练网络之前我们需要先定义优化器和损失函数。torch.optim包中包括了各种优化器,这里我们选用最常见的SGD作为优化器。因为我们要对网络的参数进行优化,所以我们要把网络的参数net.parameters()传入优化器中,并设置学习率(一般小于1)。

由于这里是回归任务,我们选择torch.nn.MSELoss()作为损失函数。

由于优化器是基于梯度来优化参数的,并且梯度会保存在其中。所以在每次优化前要通过optimizer.zero_grad()把梯度置零,然后再后向传播及更新。

5.可视化训练过程

plt.ion() # something about plotting

for t in range(100):
 ...

 if t % 5 == 0:
  # plot and show learning process
  plt.cla()
  plt.scatter(x.data.numpy(), y.data.numpy())
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
  plt.text(0.5, 0, 'Loss=%.4f' % loss.data[0], fontdict={'size': 20, 'color': 'red'})
  plt.pause(0.1)

plt.ioff()
plt.show()

6.运行结果

PyTorch学习笔记之回归实战

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现模拟按键,自动翻页看u17漫画
Mar 17 Python
使用wxPython获取系统剪贴板中的数据的教程
May 06 Python
Python中type的构造函数参数含义说明
Jun 21 Python
Python实现的Excel文件读写类
Jul 30 Python
Python实现JSON反序列化类对象的示例
Jan 31 Python
利用python list完成最简单的DB连接池方法
Aug 09 Python
Python数据库小程序源代码
Sep 15 Python
python [:3] 实现提取数组中的数
Nov 27 Python
基于pytorch 预训练的词向量用法详解
Jan 06 Python
基于python制作简易版学生信息管理系统
Apr 20 Python
python实战之一步一步教你绘制小猪佩奇
Apr 22 Python
python调用ffmpeg命令行工具便捷操作视频示例实现过程
Nov 01 Python
Django 使用Ajax进行前后台交互的示例讲解
May 28 #Python
Python实现爬虫爬取NBA数据功能示例
May 28 #Python
Django+Ajax+jQuery实现网页动态更新的实例
May 28 #Python
Python实现合并两个列表的方法分析
May 28 #Python
django js实现部分页面刷新的示例代码
May 28 #Python
Django项目中用JS实现加载子页面并传值的方法
May 28 #Python
Python面向对象类继承和组合实例分析
May 28 #Python
You might like
聊天室php&mysql(四)
2006/10/09 PHP
PHP 模板高级篇总结
2006/12/21 PHP
php的curl封装类用法实例
2014/11/07 PHP
php通过记录IP来防止表单重复提交方法分析
2014/12/16 PHP
Symfony页面的基本创建实例详解
2015/01/26 PHP
JavaScript中的变量声明早于赋值分析
2012/03/01 Javascript
为JS扩展Array.prototype.indexOf引发的问题及解决办法
2015/01/21 Javascript
使用AngularJS来实现HTML页面嵌套的方法
2015/06/17 Javascript
js实现文字闪烁特效的方法
2015/12/17 Javascript
Javascript this 函数深入详解
2016/12/13 Javascript
利用JavaScript实现拖拽改变元素大小
2016/12/14 Javascript
JavaScript自定义分页样式
2017/01/17 Javascript
d3.js实现立体柱图的方法详解
2017/04/28 Javascript
vue2.0结合Element实现select动态控制input禁用实例
2017/05/12 Javascript
node app 打包工具pkg的具体使用
2019/01/17 Javascript
深入解析koa之中间件流程控制
2019/06/17 Javascript
vue 验证码界面实现点击后标灰并设置div按钮不可点击状态
2019/10/28 Javascript
[01:26]DOTA2荣耀之路2:iG,China
2018/05/24 DOTA
python中sets模块的用法实例
2014/09/30 Python
对Django url的几种使用方式详解
2019/08/06 Python
解决Python中回文数和质数的问题
2019/11/24 Python
Python读取表格类型文件代码实例
2020/02/17 Python
Django中文件上传和文件访问微项目的方法
2020/04/27 Python
Python基于进程池实现多进程过程解析
2020/04/30 Python
python爬虫利器之requests库的用法(超全面的爬取网页案例)
2020/12/17 Python
全球最大的网上自行车商店:Chain Reaction Cycles
2016/12/02 全球购物
常用UNIX 命令(Linux的常用命令)
2015/12/26 面试题
进修护士自我鉴定
2013/10/14 职场文书
运动会100米解说词
2014/01/23 职场文书
教育学习自我评价
2014/02/03 职场文书
保证金退回承诺函格式
2015/01/21 职场文书
员工试用期转正自我评价
2015/03/10 职场文书
婚宴父亲致辞
2015/07/27 职场文书
python pyhs2 的安装操作
2021/04/07 Python
七个Python必备的GUI库
2021/04/27 Python
Django migrate报错的解决方案
2021/05/20 Python