PyTorch学习笔记之回归实战


Posted in Python onMay 28, 2018

本文主要是用PyTorch来实现一个简单的回归任务。

编辑器:spyder

1.引入相应的包及生成伪数据

import torch
import torch.nn.functional as F # 主要实现激活函数
import matplotlib.pyplot as plt # 绘图的工具
from torch.autograd import Variable

# 生成伪数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim = 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())

# 变为Variable
x, y = Variable(x), Variable(y)

其中torch.linspace是为了生成连续间断的数据,第一个参数表示起点,第二个参数表示终点,第三个参数表示将这个区间分成平均几份,即生成几个数据。因为torch只能处理二维的数据,所以我们用torch.unsqueeze给伪数据添加一个维度,dim表示添加在第几维。torch.rand返回的是[0,1)之间的均匀分布。

2.绘制数据图像

在上述代码后面加下面的代码,然后运行可得伪数据的图形化表示:

# 绘制数据图像
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()

PyTorch学习笔记之回归实战

3.建立神经网络

class Net(torch.nn.Module):
 def __init__(self, n_feature, n_hidden, n_output):
  super(Net, self).__init__()
  self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
  self.predict = torch.nn.Linear(n_hidden, n_output) # output layer

 def forward(self, x):
  x = F.relu(self.hidden(x))  # activation function for hidden layer
  x = self.predict(x)    # linear output
  return x

net = Net(n_feature=1, n_hidden=10, n_output=1)  # define the network
print(net) # net architecture

一般神经网络的类都继承自torch.nn.Module__init__()和forward()两个函数是自定义类的主要函数。在__init__()中都要添加一句super(Net, self).__init__(),这是固定的标准写法,用于继承父类的初始化函数。__init__()中只是对神经网络的模块进行了声明,真正的搭建是在forwad()中实现。自定义类中的成员都通过self指针来进行访问,所以参数列表中都包含了self。

如果想查看网络结构,可以用print()函数直接打印网络。本文的网络结构输出如下:

Net (
 (hidden): Linear (1 -> 10)
 (predict): Linear (10 -> 1)
)

4.训练网络

# 训练100次
for t in range(100):
 prediction = net(x)  # input x and predict based on x

 loss = loss_func(prediction, y)  # 一定要是输出在前,标签在后 (1. nn output, 2. target)

 optimizer.zero_grad() # clear gradients for next train
 loss.backward()   # backpropagation, compute gradients
 optimizer.step()  # apply gradients

训练网络之前我们需要先定义优化器和损失函数。torch.optim包中包括了各种优化器,这里我们选用最常见的SGD作为优化器。因为我们要对网络的参数进行优化,所以我们要把网络的参数net.parameters()传入优化器中,并设置学习率(一般小于1)。

由于这里是回归任务,我们选择torch.nn.MSELoss()作为损失函数。

由于优化器是基于梯度来优化参数的,并且梯度会保存在其中。所以在每次优化前要通过optimizer.zero_grad()把梯度置零,然后再后向传播及更新。

5.可视化训练过程

plt.ion() # something about plotting

for t in range(100):
 ...

 if t % 5 == 0:
  # plot and show learning process
  plt.cla()
  plt.scatter(x.data.numpy(), y.data.numpy())
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
  plt.text(0.5, 0, 'Loss=%.4f' % loss.data[0], fontdict={'size': 20, 'color': 'red'})
  plt.pause(0.1)

plt.ioff()
plt.show()

6.运行结果

PyTorch学习笔记之回归实战

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现的生成自我描述脚本分享(很有意思的程序)
Jul 18 Python
跟老齐学Python之关于循环的小伎俩
Oct 02 Python
用Python生成器实现微线程编程的教程
Apr 13 Python
Python中用Spark模块的使用教程
Apr 13 Python
Python中使用ElementTree解析XML示例
Jun 02 Python
Python网络爬虫出现乱码问题的解决方法
Jan 05 Python
Python入门学习指南分享
Apr 11 Python
解决vscode python print 输出窗口中文乱码的问题
Dec 03 Python
200行python代码实现2048游戏
Jul 17 Python
python使用Matplotlib改变坐标轴的默认位置
Oct 18 Python
关于tensorflow的几种参数初始化方法小结
Jan 04 Python
对python中arange()和linspace()的区别说明
May 03 Python
Django 使用Ajax进行前后台交互的示例讲解
May 28 #Python
Python实现爬虫爬取NBA数据功能示例
May 28 #Python
Django+Ajax+jQuery实现网页动态更新的实例
May 28 #Python
Python实现合并两个列表的方法分析
May 28 #Python
django js实现部分页面刷新的示例代码
May 28 #Python
Django项目中用JS实现加载子页面并传值的方法
May 28 #Python
Python面向对象类继承和组合实例分析
May 28 #Python
You might like
php中cookie的作用域
2008/03/27 PHP
Drupal7 form表单二次开发要点与实例
2014/03/02 PHP
php中的ini配置原理详解
2014/10/14 PHP
php验证身份证号码正确性的函数
2016/07/20 PHP
微信小程序发送订阅消息的方法(php 为例)
2019/10/30 PHP
TP5框架使用QueryList采集框架爬小说操作示例
2020/03/26 PHP
PHP var关键字相关原理及使用实例解析
2020/07/11 PHP
javascript 变量作用域 代码分析
2009/06/26 Javascript
jQuery学习笔记之DOM对象和jQuery对象
2010/12/22 Javascript
javascript中xml操作实现代码
2011/11/21 Javascript
ie6下png图片背景不透明的解决办法使用js实现
2013/01/11 Javascript
Bootstrap每天必学之导航条(二)
2016/03/01 Javascript
相册展示PhotoSwipe.js插件实现
2016/08/25 Javascript
探讨AngularJs中ui.route的简单应用
2016/11/16 Javascript
js实现tab切换效果
2017/02/16 Javascript
JS实现快递单打印功能【推荐】
2018/06/21 Javascript
vue-cli的build的文件夹下没有dev-server.js文件配置mock数据的方法
2019/04/17 Javascript
了解JavaScript中let语句
2019/05/30 Javascript
jQuery实现文本显示一段时间后隐藏的方法分析
2019/06/20 jQuery
Vue项目实现简单的权限控制管理功能
2019/07/17 Javascript
JS自定义滚动条效果
2020/03/13 Javascript
[02:52]2014DOTA2西雅图国际邀请赛 CIS战队巡礼
2014/07/07 DOTA
[03:21]【TI9纪实】Old Boys
2019/08/23 DOTA
利用python爬取散文网的文章实例教程
2017/06/18 Python
疯狂上涨的Python 开发者应从2.x还是3.x着手?
2017/11/16 Python
对Xpath 获取子标签下所有文本的方法详解
2019/01/02 Python
Django实现发送邮件找回密码功能
2019/08/12 Python
python 检测图片是否有马赛克
2020/12/01 Python
html5使用canvas绘制文字特效
2014/12/15 HTML / CSS
店长助理岗位职责
2013/12/13 职场文书
团日活动总结
2014/04/28 职场文书
公务员年度个人总结
2015/02/12 职场文书
大学生活感想
2015/08/10 职场文书
Python 阶乘详解
2021/10/05 Python
Python+Tkinter打造签名设计工具
2022/04/01 Python
ubuntu端向日葵键盘输入卡顿问题及解决
2022/12/24 Servers