PyTorch搭建一维线性回归模型(二)


Posted in Python onMay 22, 2019

PyTorch基础入门二:PyTorch搭建一维线性回归模型

1)一维线性回归模型的理论基础

给定数据集PyTorch搭建一维线性回归模型(二),线性回归希望能够优化出一个好的函数PyTorch搭建一维线性回归模型(二),使得PyTorch搭建一维线性回归模型(二)能够和PyTorch搭建一维线性回归模型(二)尽可能接近。

如何才能学习到参数PyTorch搭建一维线性回归模型(二)PyTorch搭建一维线性回归模型(二)呢?很简单,只需要确定如何衡量PyTorch搭建一维线性回归模型(二)PyTorch搭建一维线性回归模型(二)之间的差别,我们一般通过损失函数(Loss Funciton)来衡量:PyTorch搭建一维线性回归模型(二)。取平方是因为距离有正有负,我们于是将它们变为全是正的。这就是著名的均方误差。我们要做的事情就是希望能够找到PyTorch搭建一维线性回归模型(二)PyTorch搭建一维线性回归模型(二),使得:

PyTorch搭建一维线性回归模型(二)

PyTorch搭建一维线性回归模型(二)

均方差误差非常直观,也有着很好的几何意义,对应了常用的欧式距离。现在要求解这个连续函数的最小值,我们很自然想到的方法就是求它的偏导数,让它的偏导数等于0来估计它的参数,即:

PyTorch搭建一维线性回归模型(二)

PyTorch搭建一维线性回归模型(二)

求解以上两式,我们就可以得到最优解。

2)代码实现

首先,我们需要“制造”出一些数据集:

import torch
import matplotlib.pyplot as plt
 
 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = 3*x + 10 + torch.rand(x.size())
# 上面这行代码是制造出接近y=3x+10的数据集,后面加上torch.rand()函数制造噪音
 
# 画图
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()

我们想要拟合的一维回归模型是PyTorch搭建一维线性回归模型(二)。上面制造的数据集也是比较接近这个模型的,但是为了达到学习效果,人为地加上了torch.rand()值增加一些干扰。

上面人为制造出来的数据集的分布如下:

PyTorch搭建一维线性回归模型(二)

有了数据,我们就要开始定义我们的模型,这里定义的是一个输入层和输出层都只有一维的模型,并且使用了“先判断后使用”的基本结构来合理使用GPU加速。

class LinearRegression(nn.Module):
  def __init__(self):
    super(LinearRegression, self).__init__()
    self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
  def forward(self, x):
    out = self.linear(x)
    return out
 
if torch.cuda.is_available():
  model = LinearRegression().cuda()
else:
  model = LinearRegression()

然后我们定义出损失函数和优化函数,这里使用均方误差作为损失函数,使用梯度下降进行优化:

criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

接下来,开始进行模型的训练。

num_epochs = 1000
for epoch in range(num_epochs):
  if torch.cuda.is_available():
    inputs = Variable(x).cuda()
    target = Variable(y).cuda()
  else:
    inputs = Variable(x)
    target = Variable(y)
 
  # 向前传播
  out = model(inputs)
  loss = criterion(out, target)
 
  # 向后传播
  optimizer.zero_grad() # 注意每次迭代都需要清零
  loss.backward()
  optimizer.step()
 
  if (epoch+1) %20 == 0:
    print('Epoch[{}/{}], loss:{:.6f}'.format(epoch+1, num_epochs, loss.data[0]))

首先定义了迭代的次数,这里为1000次,先向前传播计算出损失函数,然后向后传播计算梯度,这里需要注意的是,每次计算梯度前都要记得将梯度归零,不然梯度会累加到一起造成结果不收敛。为了便于看到结果,每隔一段时间输出当前的迭代轮数和损失函数。

接下来,我们通过model.eval()函数将模型变为测试模式,然后将数据放入模型中进行预测。最后,通过画图工具matplotlib看一下我们拟合的结果,代码如下:

model.eval()
if torch.cuda.is_available():
  predict = model(Variable(x).cuda())
  predict = predict.data.cpu().numpy()
else:
  predict = model(Variable(x))
  predict = predict.data.numpy()
plt.plot(x.numpy(), y.numpy(), 'ro', label='Original Data')
plt.plot(x.numpy(), predict, label='Fitting Line')
plt.show()

其拟合结果如下图:

PyTorch搭建一维线性回归模型(二)

附上完整代码:

# !/usr/bin/python
# coding: utf8
# @Time  : 2018-07-28 18:40
# @Author : Liam
# @Email  : luyu.real@qq.com
# @Software: PyCharm
#            .::::.
#           .::::::::.
#           :::::::::::
#         ..:::::::::::'
#        '::::::::::::'
#         .::::::::::
#      '::::::::::::::..
#         ..::::::::::::.
#        ``::::::::::::::::
#        ::::``:::::::::'    .:::.
#        ::::'  ':::::'    .::::::::.
#       .::::'   ::::   .:::::::'::::.
#      .:::'    ::::: .:::::::::' ':::::.
#      .::'    :::::.:::::::::'   ':::::.
#     .::'     ::::::::::::::'     ``::::.
#   ...:::      ::::::::::::'       ``::.
#   ```` ':.     ':::::::::'         ::::..
#            '.:::::'          ':'````..
#           美女保佑 永无BUG
 
import torch
from torch.autograd import Variable
import numpy as np
import random
import matplotlib.pyplot as plt
from torch import nn
 
 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = 3*x + 10 + torch.rand(x.size())
# 上面这行代码是制造出接近y=3x+10的数据集,后面加上torch.rand()函数制造噪音
 
# 画图
# plt.scatter(x.data.numpy(), y.data.numpy())
# plt.show()
class LinearRegression(nn.Module):
  def __init__(self):
    super(LinearRegression, self).__init__()
    self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
  def forward(self, x):
    out = self.linear(x)
    return out
 
if torch.cuda.is_available():
  model = LinearRegression().cuda()
else:
  model = LinearRegression()
 
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
 
num_epochs = 1000
for epoch in range(num_epochs):
  if torch.cuda.is_available():
    inputs = Variable(x).cuda()
    target = Variable(y).cuda()
  else:
    inputs = Variable(x)
    target = Variable(y)
 
  # 向前传播
  out = model(inputs)
  loss = criterion(out, target)
 
  # 向后传播
  optimizer.zero_grad() # 注意每次迭代都需要清零
  loss.backward()
  optimizer.step()
 
  if (epoch+1) %20 == 0:
    print('Epoch[{}/{}], loss:{:.6f}'.format(epoch+1, num_epochs, loss.data[0]))
model.eval()
if torch.cuda.is_available():
  predict = model(Variable(x).cuda())
  predict = predict.data.cpu().numpy()
else:
  predict = model(Variable(x))
  predict = predict.data.numpy()
plt.plot(x.numpy(), y.numpy(), 'ro', label='Original Data')
plt.plot(x.numpy(), predict, label='Fitting Line')
plt.show()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python列表生成器的循环技巧分享
Mar 06 Python
python生成随机密码或随机字符串的方法
Jul 03 Python
python3实现域名查询和whois查询功能
Jun 21 Python
Python3内置模块random随机方法小结
Jul 13 Python
django基础学习之send_mail功能
Aug 07 Python
Django ORM多对多查询方法(自定义第三张表&ManyToManyField)
Aug 09 Python
python根据多个文件名批量查找文件
Aug 13 Python
Tensorflow实现部分参数梯度更新操作
Jan 23 Python
python 识别登录验证码图片功能的实现代码(完整代码)
Jul 03 Python
Python+OpenCV图像处理—— 色彩空间转换
Oct 22 Python
python和C++共享内存传输图像的示例
Oct 27 Python
教你怎么用PyCharm为同一服务器配置多个python解释器
May 31 Python
PyTorch基本数据类型(一)
May 22 #Python
PyTorch搭建多项式回归模型(三)
May 22 #Python
pytorch使用Variable实现线性回归
May 21 #Python
Python面向对象进阶学习
May 21 #Python
谈一谈基于python的面向对象编程基础
May 21 #Python
python字符串和常用数据结构知识总结
May 21 #Python
Opencv实现抠图背景图替换功能
May 21 #Python
You might like
PHP 超链接 抓取实现代码
2009/06/29 PHP
php如何控制用户对图片的访问 PHP禁止图片盗链
2016/03/25 PHP
PHP中Array相关函数简介
2016/07/03 PHP
掌握PHP垃圾回收机制详解
2019/03/13 PHP
css+js实现部分区域高亮可编辑遮罩层
2014/03/04 Javascript
JS 获取鼠标左右键的键值方法
2014/10/11 Javascript
javascript正则表达式基础知识入门
2015/04/20 Javascript
理解javascript中的严格模式
2016/02/01 Javascript
深入理解JS DOM事件机制
2016/08/06 Javascript
Vue.js bootstrap前端实现分页和排序
2017/03/10 Javascript
JavaScript实现封闭区域布尔运算的示例代码
2018/06/25 Javascript
jQuery实现获取动态添加的标签对象示例
2018/06/28 jQuery
在vue-cli 3中给stylus、sass样式传入共享的全局变量
2019/08/12 Javascript
Vuex modules模式下mapState/mapMutations的操作实例
2019/10/17 Javascript
布同 Python中文问题解决方法(总结了多位前人经验,初学者必看)
2011/03/13 Python
python实现各进制转换的总结大全
2017/06/18 Python
python学习笔记--将python源文件打包成exe文件(pyinstaller)
2018/05/26 Python
对python中if语句的真假判断实例详解
2019/02/18 Python
Python参数解析模块sys、getopt、argparse使用与对比分析
2019/04/02 Python
Keras—embedding嵌入层的用法详解
2020/06/10 Python
Python unittest装饰器实现原理及代码
2020/09/08 Python
python RSA加密的示例
2020/12/09 Python
深入了解canvas在移动端绘制模糊的问题解决
2019/04/30 HTML / CSS
AmazeUI的JS表单验证框架实战示例分享
2020/08/21 HTML / CSS
联想印度官方网上商店:Lenovo India
2019/08/24 全球购物
舞蹈专业大学生职业规划范文
2014/03/12 职场文书
酒店采购员岗位职责
2014/03/14 职场文书
运动员口号
2014/06/09 职场文书
高中军训的心得体会
2014/09/01 职场文书
教师三严三实学习心得体会
2014/10/11 职场文书
数据结构课程设计心得体会
2016/01/15 职场文书
python爬虫请求库httpx和parsel解析库的使用测评
2021/05/10 Python
idea编译器vue缩进报错问题场景分析
2021/07/04 Vue.js
python自动化八大定位元素讲解
2021/07/09 Python
第四次工业革命,打工人与机器人的竞争
2022/04/21 数码科技
5个pandas调用函数的方法让数据处理更加灵活自如
2022/04/24 Python