PyTorch学习笔记之回归实战


Posted in Python onMay 28, 2018

本文主要是用PyTorch来实现一个简单的回归任务。

编辑器:spyder

1.引入相应的包及生成伪数据

import torch
import torch.nn.functional as F # 主要实现激活函数
import matplotlib.pyplot as plt # 绘图的工具
from torch.autograd import Variable

# 生成伪数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim = 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())

# 变为Variable
x, y = Variable(x), Variable(y)

其中torch.linspace是为了生成连续间断的数据,第一个参数表示起点,第二个参数表示终点,第三个参数表示将这个区间分成平均几份,即生成几个数据。因为torch只能处理二维的数据,所以我们用torch.unsqueeze给伪数据添加一个维度,dim表示添加在第几维。torch.rand返回的是[0,1)之间的均匀分布。

2.绘制数据图像

在上述代码后面加下面的代码,然后运行可得伪数据的图形化表示:

# 绘制数据图像
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()

PyTorch学习笔记之回归实战

3.建立神经网络

class Net(torch.nn.Module):
 def __init__(self, n_feature, n_hidden, n_output):
  super(Net, self).__init__()
  self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
  self.predict = torch.nn.Linear(n_hidden, n_output) # output layer

 def forward(self, x):
  x = F.relu(self.hidden(x))  # activation function for hidden layer
  x = self.predict(x)    # linear output
  return x

net = Net(n_feature=1, n_hidden=10, n_output=1)  # define the network
print(net) # net architecture

一般神经网络的类都继承自torch.nn.Module__init__()和forward()两个函数是自定义类的主要函数。在__init__()中都要添加一句super(Net, self).__init__(),这是固定的标准写法,用于继承父类的初始化函数。__init__()中只是对神经网络的模块进行了声明,真正的搭建是在forwad()中实现。自定义类中的成员都通过self指针来进行访问,所以参数列表中都包含了self。

如果想查看网络结构,可以用print()函数直接打印网络。本文的网络结构输出如下:

Net (
 (hidden): Linear (1 -> 10)
 (predict): Linear (10 -> 1)
)

4.训练网络

# 训练100次
for t in range(100):
 prediction = net(x)  # input x and predict based on x

 loss = loss_func(prediction, y)  # 一定要是输出在前,标签在后 (1. nn output, 2. target)

 optimizer.zero_grad() # clear gradients for next train
 loss.backward()   # backpropagation, compute gradients
 optimizer.step()  # apply gradients

训练网络之前我们需要先定义优化器和损失函数。torch.optim包中包括了各种优化器,这里我们选用最常见的SGD作为优化器。因为我们要对网络的参数进行优化,所以我们要把网络的参数net.parameters()传入优化器中,并设置学习率(一般小于1)。

由于这里是回归任务,我们选择torch.nn.MSELoss()作为损失函数。

由于优化器是基于梯度来优化参数的,并且梯度会保存在其中。所以在每次优化前要通过optimizer.zero_grad()把梯度置零,然后再后向传播及更新。

5.可视化训练过程

plt.ion() # something about plotting

for t in range(100):
 ...

 if t % 5 == 0:
  # plot and show learning process
  plt.cla()
  plt.scatter(x.data.numpy(), y.data.numpy())
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
  plt.text(0.5, 0, 'Loss=%.4f' % loss.data[0], fontdict={'size': 20, 'color': 'red'})
  plt.pause(0.1)

plt.ioff()
plt.show()

6.运行结果

PyTorch学习笔记之回归实战

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python计算程序运行时间的方法
Dec 13 Python
Python的Bottle框架中获取制定cookie的教程
Apr 24 Python
为Python的web框架编写MVC配置来使其运行的教程
Apr 30 Python
Python3对称加密算法AES、DES3实例详解
Dec 06 Python
python 判断文件还是文件夹的简单实例
Jun 10 Python
Python Django Cookie 简单用法解析
Aug 13 Python
python同步windows和linux文件
Aug 29 Python
python 比较字典value的最大值的几种方法
Apr 17 Python
Python-jenkins模块获取jobs的执行状态操作
May 12 Python
python连接mysql数据库并读取数据的实现
Sep 25 Python
python实现Nao机器人的单目测距
Sep 04 Python
python机器学习Github已达8.9Kstars模型解释器LIME
Nov 23 Python
Django 使用Ajax进行前后台交互的示例讲解
May 28 #Python
Python实现爬虫爬取NBA数据功能示例
May 28 #Python
Django+Ajax+jQuery实现网页动态更新的实例
May 28 #Python
Python实现合并两个列表的方法分析
May 28 #Python
django js实现部分页面刷新的示例代码
May 28 #Python
Django项目中用JS实现加载子页面并传值的方法
May 28 #Python
Python面向对象类继承和组合实例分析
May 28 #Python
You might like
杏林同学录(二)
2006/10/09 PHP
带密匙的php加密解密示例分享
2014/01/29 PHP
制作安全性高的PHP网站的几个实用要点
2014/12/30 PHP
完美解决phpdoc导出文档中@package的warning及Error的错误
2016/05/17 PHP
js网页中的(运行代码)功能实现思路
2013/02/04 Javascript
可简单避免的三个JS发布错误的详细介绍
2013/08/02 Javascript
jquery获取一组checkbox的值(实例代码)
2013/11/04 Javascript
用C/C++来实现 Node.js 的模块(一)
2014/09/24 Javascript
JavaScript中判断整数的多种方法总结
2014/11/08 Javascript
莱鸟介绍javascript onclick事件
2016/01/06 Javascript
探讨:JavaScript ECAMScript5 新特性之get/set访问器
2016/05/05 Javascript
判断js的Array和Object的实现方法
2016/08/29 Javascript
关于TypeScript中import JSON的正确姿势详解
2017/07/25 Javascript
jQuery模拟爆炸倒计时功能实例代码
2017/08/21 jQuery
三分钟学会用ES7中的Async/Await进行异步编程
2018/06/14 Javascript
[04:47]DOTA2-潍坊风行电子俱乐部探秘
2014/08/08 DOTA
Python3使用requests登录人人影视网站的方法
2016/05/11 Python
浅谈pandas中DataFrame关于显示值省略的解决方法
2018/04/08 Python
python如何创建TCP服务端和客户端
2018/08/26 Python
python批量赋值操作实例
2018/10/22 Python
Django中的forms组件实例详解
2018/11/08 Python
解决Python下json.loads()中文字符出错的问题
2018/12/19 Python
对python使用telnet实现弱密码登录的方法详解
2019/01/26 Python
在Django的View中使用asyncio的方法
2019/07/12 Python
python批量修改ssh密码的实现
2019/08/08 Python
python基础 range的用法解析
2019/08/23 Python
Python使用文件操作实现一个XX信息管理系统的示例
2020/07/02 Python
详解pycharm配置python解释器的问题
2020/10/15 Python
css3的动画特效之动画序列(animation)
2017/12/22 HTML / CSS
以下为Windows NT 下的32 位C++程序,请计算sizeof 的值
2016/12/07 面试题
实习老师个人总结的自我评价
2013/09/28 职场文书
三年级小学生评语
2014/04/22 职场文书
2014年党员承诺书范文
2014/05/20 职场文书
争先创优演讲稿
2014/09/15 职场文书
六一儿童节标语
2014/10/08 职场文书
win server2012 r2服务器共享文件夹如何设置
2022/06/21 Servers