PyTorch学习笔记之回归实战


Posted in Python onMay 28, 2018

本文主要是用PyTorch来实现一个简单的回归任务。

编辑器:spyder

1.引入相应的包及生成伪数据

import torch
import torch.nn.functional as F # 主要实现激活函数
import matplotlib.pyplot as plt # 绘图的工具
from torch.autograd import Variable

# 生成伪数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim = 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())

# 变为Variable
x, y = Variable(x), Variable(y)

其中torch.linspace是为了生成连续间断的数据,第一个参数表示起点,第二个参数表示终点,第三个参数表示将这个区间分成平均几份,即生成几个数据。因为torch只能处理二维的数据,所以我们用torch.unsqueeze给伪数据添加一个维度,dim表示添加在第几维。torch.rand返回的是[0,1)之间的均匀分布。

2.绘制数据图像

在上述代码后面加下面的代码,然后运行可得伪数据的图形化表示:

# 绘制数据图像
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()

PyTorch学习笔记之回归实战

3.建立神经网络

class Net(torch.nn.Module):
 def __init__(self, n_feature, n_hidden, n_output):
  super(Net, self).__init__()
  self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
  self.predict = torch.nn.Linear(n_hidden, n_output) # output layer

 def forward(self, x):
  x = F.relu(self.hidden(x))  # activation function for hidden layer
  x = self.predict(x)    # linear output
  return x

net = Net(n_feature=1, n_hidden=10, n_output=1)  # define the network
print(net) # net architecture

一般神经网络的类都继承自torch.nn.Module__init__()和forward()两个函数是自定义类的主要函数。在__init__()中都要添加一句super(Net, self).__init__(),这是固定的标准写法,用于继承父类的初始化函数。__init__()中只是对神经网络的模块进行了声明,真正的搭建是在forwad()中实现。自定义类中的成员都通过self指针来进行访问,所以参数列表中都包含了self。

如果想查看网络结构,可以用print()函数直接打印网络。本文的网络结构输出如下:

Net (
 (hidden): Linear (1 -> 10)
 (predict): Linear (10 -> 1)
)

4.训练网络

# 训练100次
for t in range(100):
 prediction = net(x)  # input x and predict based on x

 loss = loss_func(prediction, y)  # 一定要是输出在前,标签在后 (1. nn output, 2. target)

 optimizer.zero_grad() # clear gradients for next train
 loss.backward()   # backpropagation, compute gradients
 optimizer.step()  # apply gradients

训练网络之前我们需要先定义优化器和损失函数。torch.optim包中包括了各种优化器,这里我们选用最常见的SGD作为优化器。因为我们要对网络的参数进行优化,所以我们要把网络的参数net.parameters()传入优化器中,并设置学习率(一般小于1)。

由于这里是回归任务,我们选择torch.nn.MSELoss()作为损失函数。

由于优化器是基于梯度来优化参数的,并且梯度会保存在其中。所以在每次优化前要通过optimizer.zero_grad()把梯度置零,然后再后向传播及更新。

5.可视化训练过程

plt.ion() # something about plotting

for t in range(100):
 ...

 if t % 5 == 0:
  # plot and show learning process
  plt.cla()
  plt.scatter(x.data.numpy(), y.data.numpy())
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
  plt.text(0.5, 0, 'Loss=%.4f' % loss.data[0], fontdict={'size': 20, 'color': 'red'})
  plt.pause(0.1)

plt.ioff()
plt.show()

6.运行结果

PyTorch学习笔记之回归实战

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
常见python正则用法的简单实例
Jun 21 Python
Python Socket使用实例
Dec 18 Python
详解Python自建logging模块
Jan 29 Python
Python实现的生产者、消费者问题完整实例
May 30 Python
让Python脚本暂停执行的几种方法(小结)
Jul 11 Python
django 类视图的使用方法详解
Jul 24 Python
python多线程扫描端口(线程池)
Sep 04 Python
python验证码图片处理(二值化)
Nov 01 Python
Python3以GitHub为例来实现模拟登录和爬取的实例讲解
Jul 30 Python
简述 Python 的类和对象
Aug 21 Python
python基于scrapy爬取京东笔记本电脑数据并进行简单处理和分析
Apr 14 Python
pandas进行数据输入和输出的方法详解
Mar 23 Python
Django 使用Ajax进行前后台交互的示例讲解
May 28 #Python
Python实现爬虫爬取NBA数据功能示例
May 28 #Python
Django+Ajax+jQuery实现网页动态更新的实例
May 28 #Python
Python实现合并两个列表的方法分析
May 28 #Python
django js实现部分页面刷新的示例代码
May 28 #Python
Django项目中用JS实现加载子页面并传值的方法
May 28 #Python
Python面向对象类继承和组合实例分析
May 28 #Python
You might like
php 输出双引号"与单引号'的方法
2010/05/09 PHP
深入php-fpm的两种进程管理模式详解
2013/06/03 PHP
PHP临时文件的安全性分析
2014/07/04 PHP
PHP中的Trait 特性及作用
2016/04/03 PHP
简单PHP会话(session)说明介绍
2016/08/21 PHP
Thinkphp结合AJAX长轮询实现PC与APP推送详解
2017/07/31 PHP
Laravel自定义 封装便捷返回Json数据格式的引用方法
2019/09/29 PHP
有一段有意思的代码-javascript现实多行信息
2007/08/26 Javascript
JavaScript 定义function的三种方式小结
2009/10/16 Javascript
JS获取dom 对象 ajax操作 读写cookie函数
2009/11/18 Javascript
MooBox 基于Mootools的对话框插件
2012/01/20 Javascript
javascript实现checkBox的全选,反选与赋值
2015/03/12 Javascript
基于jQuery实现表格内容的筛选功能
2016/08/21 Javascript
JavaScript鼠标特效大全
2016/09/13 Javascript
nodejs利用http模块实现银行卡所属银行查询和骚扰电话验证示例
2016/12/30 NodeJs
MUI 上拉刷新/下拉加载功能实例代码
2017/04/13 Javascript
vue 的keep-alive缓存功能的实现
2018/03/22 Javascript
解决使用bootstrap的dropdown部件时报错:error:Bootstrap dropdown require Popper.js问题
2018/08/30 Javascript
vue轮播组件实现$children和$parent 附带好用的gif录制工具
2019/09/26 Javascript
[01:29]Ti4循环赛第三日精彩回顾
2014/07/13 DOTA
[01:08:00]Fnatic vs Winstrike 2018国际邀请赛小组赛BO2 第一场 8.18
2018/08/19 DOTA
Python编程使用tkinter模块实现计算器软件完整代码示例
2017/11/29 Python
一篇文章快速了解Python的GIL
2018/01/12 Python
Django项目开发中cookies和session的常用操作分析
2018/07/03 Python
使用Python批量修改文件名的代码实例
2019/01/24 Python
利用pyuic5将ui文件转换为py文件的方法
2019/06/19 Python
django框架模板中定义变量(set variable in django template)的方法分析
2019/06/24 Python
Python3 pandas 操作列表实例详解
2019/09/23 Python
Python修改列表值问题解决方案
2020/03/06 Python
Python3自带工具2to3.py 转换 Python2.x 代码到Python3的操作
2021/03/03 Python
优质飞蝇钓和渔具:RiverBum
2020/05/10 全球购物
应用艺术毕业生的自我评价
2013/12/04 职场文书
人事专员的岗位职责
2014/03/01 职场文书
2015元旦晚会主持人开场白+结束语
2014/12/14 职场文书
Vue3.0 手写放大镜效果
2021/07/25 Vue.js
Spring Boot DevTools 全局配置学习指南
2022/03/31 Java/Android