PyTorch搭建一维线性回归模型(二)


Posted in Python onMay 22, 2019

PyTorch基础入门二:PyTorch搭建一维线性回归模型

1)一维线性回归模型的理论基础

给定数据集PyTorch搭建一维线性回归模型(二),线性回归希望能够优化出一个好的函数PyTorch搭建一维线性回归模型(二),使得PyTorch搭建一维线性回归模型(二)能够和PyTorch搭建一维线性回归模型(二)尽可能接近。

如何才能学习到参数PyTorch搭建一维线性回归模型(二)PyTorch搭建一维线性回归模型(二)呢?很简单,只需要确定如何衡量PyTorch搭建一维线性回归模型(二)PyTorch搭建一维线性回归模型(二)之间的差别,我们一般通过损失函数(Loss Funciton)来衡量:PyTorch搭建一维线性回归模型(二)。取平方是因为距离有正有负,我们于是将它们变为全是正的。这就是著名的均方误差。我们要做的事情就是希望能够找到PyTorch搭建一维线性回归模型(二)PyTorch搭建一维线性回归模型(二),使得:

PyTorch搭建一维线性回归模型(二)

PyTorch搭建一维线性回归模型(二)

均方差误差非常直观,也有着很好的几何意义,对应了常用的欧式距离。现在要求解这个连续函数的最小值,我们很自然想到的方法就是求它的偏导数,让它的偏导数等于0来估计它的参数,即:

PyTorch搭建一维线性回归模型(二)

PyTorch搭建一维线性回归模型(二)

求解以上两式,我们就可以得到最优解。

2)代码实现

首先,我们需要“制造”出一些数据集:

import torch
import matplotlib.pyplot as plt
 
 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = 3*x + 10 + torch.rand(x.size())
# 上面这行代码是制造出接近y=3x+10的数据集,后面加上torch.rand()函数制造噪音
 
# 画图
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()

我们想要拟合的一维回归模型是PyTorch搭建一维线性回归模型(二)。上面制造的数据集也是比较接近这个模型的,但是为了达到学习效果,人为地加上了torch.rand()值增加一些干扰。

上面人为制造出来的数据集的分布如下:

PyTorch搭建一维线性回归模型(二)

有了数据,我们就要开始定义我们的模型,这里定义的是一个输入层和输出层都只有一维的模型,并且使用了“先判断后使用”的基本结构来合理使用GPU加速。

class LinearRegression(nn.Module):
  def __init__(self):
    super(LinearRegression, self).__init__()
    self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
  def forward(self, x):
    out = self.linear(x)
    return out
 
if torch.cuda.is_available():
  model = LinearRegression().cuda()
else:
  model = LinearRegression()

然后我们定义出损失函数和优化函数,这里使用均方误差作为损失函数,使用梯度下降进行优化:

criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

接下来,开始进行模型的训练。

num_epochs = 1000
for epoch in range(num_epochs):
  if torch.cuda.is_available():
    inputs = Variable(x).cuda()
    target = Variable(y).cuda()
  else:
    inputs = Variable(x)
    target = Variable(y)
 
  # 向前传播
  out = model(inputs)
  loss = criterion(out, target)
 
  # 向后传播
  optimizer.zero_grad() # 注意每次迭代都需要清零
  loss.backward()
  optimizer.step()
 
  if (epoch+1) %20 == 0:
    print('Epoch[{}/{}], loss:{:.6f}'.format(epoch+1, num_epochs, loss.data[0]))

首先定义了迭代的次数,这里为1000次,先向前传播计算出损失函数,然后向后传播计算梯度,这里需要注意的是,每次计算梯度前都要记得将梯度归零,不然梯度会累加到一起造成结果不收敛。为了便于看到结果,每隔一段时间输出当前的迭代轮数和损失函数。

接下来,我们通过model.eval()函数将模型变为测试模式,然后将数据放入模型中进行预测。最后,通过画图工具matplotlib看一下我们拟合的结果,代码如下:

model.eval()
if torch.cuda.is_available():
  predict = model(Variable(x).cuda())
  predict = predict.data.cpu().numpy()
else:
  predict = model(Variable(x))
  predict = predict.data.numpy()
plt.plot(x.numpy(), y.numpy(), 'ro', label='Original Data')
plt.plot(x.numpy(), predict, label='Fitting Line')
plt.show()

其拟合结果如下图:

PyTorch搭建一维线性回归模型(二)

附上完整代码:

# !/usr/bin/python
# coding: utf8
# @Time  : 2018-07-28 18:40
# @Author : Liam
# @Email  : luyu.real@qq.com
# @Software: PyCharm
#            .::::.
#           .::::::::.
#           :::::::::::
#         ..:::::::::::'
#        '::::::::::::'
#         .::::::::::
#      '::::::::::::::..
#         ..::::::::::::.
#        ``::::::::::::::::
#        ::::``:::::::::'    .:::.
#        ::::'  ':::::'    .::::::::.
#       .::::'   ::::   .:::::::'::::.
#      .:::'    ::::: .:::::::::' ':::::.
#      .::'    :::::.:::::::::'   ':::::.
#     .::'     ::::::::::::::'     ``::::.
#   ...:::      ::::::::::::'       ``::.
#   ```` ':.     ':::::::::'         ::::..
#            '.:::::'          ':'````..
#           美女保佑 永无BUG
 
import torch
from torch.autograd import Variable
import numpy as np
import random
import matplotlib.pyplot as plt
from torch import nn
 
 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = 3*x + 10 + torch.rand(x.size())
# 上面这行代码是制造出接近y=3x+10的数据集,后面加上torch.rand()函数制造噪音
 
# 画图
# plt.scatter(x.data.numpy(), y.data.numpy())
# plt.show()
class LinearRegression(nn.Module):
  def __init__(self):
    super(LinearRegression, self).__init__()
    self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
  def forward(self, x):
    out = self.linear(x)
    return out
 
if torch.cuda.is_available():
  model = LinearRegression().cuda()
else:
  model = LinearRegression()
 
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
 
num_epochs = 1000
for epoch in range(num_epochs):
  if torch.cuda.is_available():
    inputs = Variable(x).cuda()
    target = Variable(y).cuda()
  else:
    inputs = Variable(x)
    target = Variable(y)
 
  # 向前传播
  out = model(inputs)
  loss = criterion(out, target)
 
  # 向后传播
  optimizer.zero_grad() # 注意每次迭代都需要清零
  loss.backward()
  optimizer.step()
 
  if (epoch+1) %20 == 0:
    print('Epoch[{}/{}], loss:{:.6f}'.format(epoch+1, num_epochs, loss.data[0]))
model.eval()
if torch.cuda.is_available():
  predict = model(Variable(x).cuda())
  predict = predict.data.cpu().numpy()
else:
  predict = model(Variable(x))
  predict = predict.data.numpy()
plt.plot(x.numpy(), y.numpy(), 'ro', label='Original Data')
plt.plot(x.numpy(), predict, label='Fitting Line')
plt.show()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python与Redis的连接教程
Apr 22 Python
Python实现快速多线程ping的方法
Jul 15 Python
python 根据pid杀死相应进程的方法
Jan 16 Python
python实现两个文件合并功能
Apr 01 Python
python逐行读写txt文件的实例讲解
Apr 03 Python
详解Python 数据库的Connection、Cursor两大对象
Jun 25 Python
python实现从pdf文件中提取文本,并自动翻译的方法
Nov 28 Python
解决python中无法自动补全代码的问题
Dec 04 Python
linux查找当前python解释器的位置方法
Feb 20 Python
详解用Python实现自动化监控远程服务器
May 18 Python
详解Tensorflow不同版本要求与CUDA及CUDNN版本对应关系
Aug 04 Python
python装饰器三种装饰模式的简单分析
Sep 04 Python
PyTorch基本数据类型(一)
May 22 #Python
PyTorch搭建多项式回归模型(三)
May 22 #Python
pytorch使用Variable实现线性回归
May 21 #Python
Python面向对象进阶学习
May 21 #Python
谈一谈基于python的面向对象编程基础
May 21 #Python
python字符串和常用数据结构知识总结
May 21 #Python
Opencv实现抠图背景图替换功能
May 21 #Python
You might like
PHP使用JSON和将json还原成数组
2015/02/12 PHP
php数据库的增删改查 php与javascript之间的交互
2017/08/31 PHP
PHP类的自动加载与命名空间用法实例分析
2020/06/05 PHP
JavaScript自定义DateDiff函数(兼容所有浏览器)
2012/03/01 Javascript
jquery.blockUI.js上传滚动等待效果实现思路及代码
2013/03/18 Javascript
基于JavaScript 下namespace 功能的简单分析
2013/07/05 Javascript
JQuery 获取json数据$.getJSON方法的实例代码
2013/08/02 Javascript
jquery 设置元素相对于另一个元素的top值(实例代码)
2013/11/06 Javascript
JavaScript的各种常见函数定义方法
2014/09/16 Javascript
javascript使用call调用微信API
2014/12/15 Javascript
简单总结JavaScript中的String字符串类型
2016/05/26 Javascript
浅谈AngularJs指令之scope属性详解
2016/10/24 Javascript
BootStrap Fileinput的使用教程
2016/12/30 Javascript
webpack3+React 的配置全解
2017/08/21 Javascript
理理Vue细节(推荐)
2019/04/16 Javascript
D3.js(v3)+react 实现带坐标与比例尺的柱形图 (V3版本)
2019/05/09 Javascript
jQuery实现tab栏切换效果
2020/12/22 jQuery
在Python下利用OpenCV来旋转图像的教程
2015/04/16 Python
win与linux系统中python requests 安装
2016/12/04 Python
Anaconda 离线安装 python 包的操作方法
2018/06/11 Python
python读取一个目录下所有txt里面的内容方法
2018/06/23 Python
一文了解Python并发编程的工程实现方法
2019/05/31 Python
python 非线性规划方式(scipy.optimize.minimize)
2020/02/11 Python
Django如何在不停机的情况下创建索引
2020/08/02 Python
python使用隐式循环快速求和的实现示例
2020/09/11 Python
Omio中国:全欧洲低价大巴、火车和航班搜索和比价
2018/08/09 全球购物
美国在线宠物商店:Chewy
2019/01/12 全球购物
性能测试工程师的面试题
2015/02/20 面试题
促销活动总结报告
2014/04/26 职场文书
基层工作经验证明样本
2014/11/16 职场文书
入党培养人考察意见
2015/06/08 职场文书
《废话连篇——致新手》——chinapizza
2022/04/05 无线电
mysql查找连续出现n次以上的数字
2022/05/11 MySQL
Python Matplotlib绘制动画的代码详解
2022/05/30 Python
Echarts如何重新渲染实例详解
2022/05/30 Javascript
MySQL8.0 Undo Tablespace管理详解
2022/06/16 MySQL