PyTorch学习笔记之回归实战


Posted in Python onMay 28, 2018

本文主要是用PyTorch来实现一个简单的回归任务。

编辑器:spyder

1.引入相应的包及生成伪数据

import torch
import torch.nn.functional as F # 主要实现激活函数
import matplotlib.pyplot as plt # 绘图的工具
from torch.autograd import Variable

# 生成伪数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim = 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())

# 变为Variable
x, y = Variable(x), Variable(y)

其中torch.linspace是为了生成连续间断的数据,第一个参数表示起点,第二个参数表示终点,第三个参数表示将这个区间分成平均几份,即生成几个数据。因为torch只能处理二维的数据,所以我们用torch.unsqueeze给伪数据添加一个维度,dim表示添加在第几维。torch.rand返回的是[0,1)之间的均匀分布。

2.绘制数据图像

在上述代码后面加下面的代码,然后运行可得伪数据的图形化表示:

# 绘制数据图像
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()

PyTorch学习笔记之回归实战

3.建立神经网络

class Net(torch.nn.Module):
 def __init__(self, n_feature, n_hidden, n_output):
  super(Net, self).__init__()
  self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
  self.predict = torch.nn.Linear(n_hidden, n_output) # output layer

 def forward(self, x):
  x = F.relu(self.hidden(x))  # activation function for hidden layer
  x = self.predict(x)    # linear output
  return x

net = Net(n_feature=1, n_hidden=10, n_output=1)  # define the network
print(net) # net architecture

一般神经网络的类都继承自torch.nn.Module__init__()和forward()两个函数是自定义类的主要函数。在__init__()中都要添加一句super(Net, self).__init__(),这是固定的标准写法,用于继承父类的初始化函数。__init__()中只是对神经网络的模块进行了声明,真正的搭建是在forwad()中实现。自定义类中的成员都通过self指针来进行访问,所以参数列表中都包含了self。

如果想查看网络结构,可以用print()函数直接打印网络。本文的网络结构输出如下:

Net (
 (hidden): Linear (1 -> 10)
 (predict): Linear (10 -> 1)
)

4.训练网络

# 训练100次
for t in range(100):
 prediction = net(x)  # input x and predict based on x

 loss = loss_func(prediction, y)  # 一定要是输出在前,标签在后 (1. nn output, 2. target)

 optimizer.zero_grad() # clear gradients for next train
 loss.backward()   # backpropagation, compute gradients
 optimizer.step()  # apply gradients

训练网络之前我们需要先定义优化器和损失函数。torch.optim包中包括了各种优化器,这里我们选用最常见的SGD作为优化器。因为我们要对网络的参数进行优化,所以我们要把网络的参数net.parameters()传入优化器中,并设置学习率(一般小于1)。

由于这里是回归任务,我们选择torch.nn.MSELoss()作为损失函数。

由于优化器是基于梯度来优化参数的,并且梯度会保存在其中。所以在每次优化前要通过optimizer.zero_grad()把梯度置零,然后再后向传播及更新。

5.可视化训练过程

plt.ion() # something about plotting

for t in range(100):
 ...

 if t % 5 == 0:
  # plot and show learning process
  plt.cla()
  plt.scatter(x.data.numpy(), y.data.numpy())
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
  plt.text(0.5, 0, 'Loss=%.4f' % loss.data[0], fontdict={'size': 20, 'color': 'red'})
  plt.pause(0.1)

plt.ioff()
plt.show()

6.运行结果

PyTorch学习笔记之回归实战

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 测试实现方法
Dec 24 Python
python列表去重的二种方法
Feb 14 Python
详解Django通用视图中的函数包装
Jul 21 Python
Using Django with GAE Python 后台抓取多个网站的页面全文
Feb 17 Python
python使用电子邮件模块smtplib的方法
Aug 28 Python
Python实现PS滤镜Fish lens图像扭曲效果示例
Jan 29 Python
Python决策树和随机森林算法实例详解
Jan 30 Python
tensorflow获取变量维度信息
Mar 10 Python
python区块及区块链的开发详解
Jul 03 Python
Python使用循环神经网络解决文本分类问题的方法详解
Jan 16 Python
解决echarts中饼图标签重叠的问题
May 16 Python
keras自动编码器实现系列之卷积自动编码器操作
Jul 03 Python
Django 使用Ajax进行前后台交互的示例讲解
May 28 #Python
Python实现爬虫爬取NBA数据功能示例
May 28 #Python
Django+Ajax+jQuery实现网页动态更新的实例
May 28 #Python
Python实现合并两个列表的方法分析
May 28 #Python
django js实现部分页面刷新的示例代码
May 28 #Python
Django项目中用JS实现加载子页面并传值的方法
May 28 #Python
Python面向对象类继承和组合实例分析
May 28 #Python
You might like
PHP中创建和验证哈希的简单方法实探
2015/07/06 PHP
Yii2.0 Basic代码中路由链接被转义的处理方法
2016/09/21 PHP
ThinkPHP发送邮件示例代码
2016/10/08 PHP
php+redis实现消息队列功能示例
2019/09/19 PHP
Laravel基础_关于view共享数据的示例讲解
2019/10/14 PHP
javascript自执行函数之伪命名空间封装法
2010/12/25 Javascript
用jquery实现的模拟QQ邮箱里的收件人选取及其他效果(一)
2011/01/06 Javascript
in.js 一个轻量级的JavaScript颗粒化模块加载和依赖关系管理解决方案
2011/07/26 Javascript
javascript写的简单的计算器,内容很多,方法实用,推荐
2011/12/29 Javascript
jquery交替变换颜色的三种方法 实例代码
2013/11/19 Javascript
nodejs开发环境配置与使用
2014/11/17 NodeJs
AngularJS中一般函数参数传递用法分析
2016/11/22 Javascript
Bootstrap源码解读导航(6)
2016/12/23 Javascript
Postman模拟发送带token的请求方法
2018/03/31 Javascript
vue项目使用axios发送请求让ajax请求头部携带cookie的方法
2018/09/26 Javascript
微信小程序云开发修改云数据库中的数据方法
2019/05/18 Javascript
element日历calendar组件上月、今天、下月、日历块点击事件及模板源码
2020/07/27 Javascript
解决vue cli4升级sass-loader(v8)后报错问题
2020/07/30 Javascript
vue 公共列表选择组件,引用Vant-UI的样式方式
2020/11/02 Javascript
简单介绍Python中的JSON使用
2015/04/28 Python
Python基于Socket实现的简单聊天程序示例
2017/08/05 Python
python3.6实现学生信息管理系统
2019/02/21 Python
通过PHP与Python代码对比的语法差异详解
2019/07/10 Python
Python测试Kafka集群(pykafka)实例
2019/12/23 Python
Python网络爬虫信息提取mooc代码实例
2020/03/06 Python
css3个性化字体_动力节点Java学院整理
2017/07/12 HTML / CSS
女装和独特珠宝:Sundance Catalog
2018/09/19 全球购物
一套Java笔试题
2016/08/20 面试题
商得四方公司面试题(gid+)
2014/04/30 面试题
入党思想汇报
2014/01/05 职场文书
优秀经理获奖感言
2014/03/04 职场文书
交警正风肃纪剖析材料
2014/10/29 职场文书
2015年为民办实事工作总结
2015/05/26 职场文书
学习师德师风的心得体会(2篇)
2019/10/08 职场文书
Java集成swagger文档组件
2021/06/28 Java/Android
Python类方法总结讲解
2021/07/26 Python