PyTorch搭建一维线性回归模型(二)


Posted in Python onMay 22, 2019

PyTorch基础入门二:PyTorch搭建一维线性回归模型

1)一维线性回归模型的理论基础

给定数据集PyTorch搭建一维线性回归模型(二),线性回归希望能够优化出一个好的函数PyTorch搭建一维线性回归模型(二),使得PyTorch搭建一维线性回归模型(二)能够和PyTorch搭建一维线性回归模型(二)尽可能接近。

如何才能学习到参数PyTorch搭建一维线性回归模型(二)PyTorch搭建一维线性回归模型(二)呢?很简单,只需要确定如何衡量PyTorch搭建一维线性回归模型(二)PyTorch搭建一维线性回归模型(二)之间的差别,我们一般通过损失函数(Loss Funciton)来衡量:PyTorch搭建一维线性回归模型(二)。取平方是因为距离有正有负,我们于是将它们变为全是正的。这就是著名的均方误差。我们要做的事情就是希望能够找到PyTorch搭建一维线性回归模型(二)PyTorch搭建一维线性回归模型(二),使得:

PyTorch搭建一维线性回归模型(二)

PyTorch搭建一维线性回归模型(二)

均方差误差非常直观,也有着很好的几何意义,对应了常用的欧式距离。现在要求解这个连续函数的最小值,我们很自然想到的方法就是求它的偏导数,让它的偏导数等于0来估计它的参数,即:

PyTorch搭建一维线性回归模型(二)

PyTorch搭建一维线性回归模型(二)

求解以上两式,我们就可以得到最优解。

2)代码实现

首先,我们需要“制造”出一些数据集:

import torch
import matplotlib.pyplot as plt
 
 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = 3*x + 10 + torch.rand(x.size())
# 上面这行代码是制造出接近y=3x+10的数据集,后面加上torch.rand()函数制造噪音
 
# 画图
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()

我们想要拟合的一维回归模型是PyTorch搭建一维线性回归模型(二)。上面制造的数据集也是比较接近这个模型的,但是为了达到学习效果,人为地加上了torch.rand()值增加一些干扰。

上面人为制造出来的数据集的分布如下:

PyTorch搭建一维线性回归模型(二)

有了数据,我们就要开始定义我们的模型,这里定义的是一个输入层和输出层都只有一维的模型,并且使用了“先判断后使用”的基本结构来合理使用GPU加速。

class LinearRegression(nn.Module):
  def __init__(self):
    super(LinearRegression, self).__init__()
    self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
  def forward(self, x):
    out = self.linear(x)
    return out
 
if torch.cuda.is_available():
  model = LinearRegression().cuda()
else:
  model = LinearRegression()

然后我们定义出损失函数和优化函数,这里使用均方误差作为损失函数,使用梯度下降进行优化:

criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

接下来,开始进行模型的训练。

num_epochs = 1000
for epoch in range(num_epochs):
  if torch.cuda.is_available():
    inputs = Variable(x).cuda()
    target = Variable(y).cuda()
  else:
    inputs = Variable(x)
    target = Variable(y)
 
  # 向前传播
  out = model(inputs)
  loss = criterion(out, target)
 
  # 向后传播
  optimizer.zero_grad() # 注意每次迭代都需要清零
  loss.backward()
  optimizer.step()
 
  if (epoch+1) %20 == 0:
    print('Epoch[{}/{}], loss:{:.6f}'.format(epoch+1, num_epochs, loss.data[0]))

首先定义了迭代的次数,这里为1000次,先向前传播计算出损失函数,然后向后传播计算梯度,这里需要注意的是,每次计算梯度前都要记得将梯度归零,不然梯度会累加到一起造成结果不收敛。为了便于看到结果,每隔一段时间输出当前的迭代轮数和损失函数。

接下来,我们通过model.eval()函数将模型变为测试模式,然后将数据放入模型中进行预测。最后,通过画图工具matplotlib看一下我们拟合的结果,代码如下:

model.eval()
if torch.cuda.is_available():
  predict = model(Variable(x).cuda())
  predict = predict.data.cpu().numpy()
else:
  predict = model(Variable(x))
  predict = predict.data.numpy()
plt.plot(x.numpy(), y.numpy(), 'ro', label='Original Data')
plt.plot(x.numpy(), predict, label='Fitting Line')
plt.show()

其拟合结果如下图:

PyTorch搭建一维线性回归模型(二)

附上完整代码:

# !/usr/bin/python
# coding: utf8
# @Time  : 2018-07-28 18:40
# @Author : Liam
# @Email  : luyu.real@qq.com
# @Software: PyCharm
#            .::::.
#           .::::::::.
#           :::::::::::
#         ..:::::::::::'
#        '::::::::::::'
#         .::::::::::
#      '::::::::::::::..
#         ..::::::::::::.
#        ``::::::::::::::::
#        ::::``:::::::::'    .:::.
#        ::::'  ':::::'    .::::::::.
#       .::::'   ::::   .:::::::'::::.
#      .:::'    ::::: .:::::::::' ':::::.
#      .::'    :::::.:::::::::'   ':::::.
#     .::'     ::::::::::::::'     ``::::.
#   ...:::      ::::::::::::'       ``::.
#   ```` ':.     ':::::::::'         ::::..
#            '.:::::'          ':'````..
#           美女保佑 永无BUG
 
import torch
from torch.autograd import Variable
import numpy as np
import random
import matplotlib.pyplot as plt
from torch import nn
 
 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = 3*x + 10 + torch.rand(x.size())
# 上面这行代码是制造出接近y=3x+10的数据集,后面加上torch.rand()函数制造噪音
 
# 画图
# plt.scatter(x.data.numpy(), y.data.numpy())
# plt.show()
class LinearRegression(nn.Module):
  def __init__(self):
    super(LinearRegression, self).__init__()
    self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
  def forward(self, x):
    out = self.linear(x)
    return out
 
if torch.cuda.is_available():
  model = LinearRegression().cuda()
else:
  model = LinearRegression()
 
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
 
num_epochs = 1000
for epoch in range(num_epochs):
  if torch.cuda.is_available():
    inputs = Variable(x).cuda()
    target = Variable(y).cuda()
  else:
    inputs = Variable(x)
    target = Variable(y)
 
  # 向前传播
  out = model(inputs)
  loss = criterion(out, target)
 
  # 向后传播
  optimizer.zero_grad() # 注意每次迭代都需要清零
  loss.backward()
  optimizer.step()
 
  if (epoch+1) %20 == 0:
    print('Epoch[{}/{}], loss:{:.6f}'.format(epoch+1, num_epochs, loss.data[0]))
model.eval()
if torch.cuda.is_available():
  predict = model(Variable(x).cuda())
  predict = predict.data.cpu().numpy()
else:
  predict = model(Variable(x))
  predict = predict.data.numpy()
plt.plot(x.numpy(), y.numpy(), 'ro', label='Original Data')
plt.plot(x.numpy(), predict, label='Fitting Line')
plt.show()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
pandas 将list切分后存入DataFrame中的实例
Jul 03 Python
PyCharm配置mongo插件的方法
Nov 30 Python
Python基于机器学习方法实现的电影推荐系统实例详解
Jun 25 Python
Python正则表达式匹配日期与时间的方法
Jul 07 Python
python 进程的几种创建方式详解
Aug 29 Python
Python 脚本拉取 Docker 镜像问题
Nov 10 Python
python解析xml文件方式(解析、更新、写入)
Mar 05 Python
详解python环境安装selenium和手动下载安装selenium的方法
Mar 17 Python
PyTorch加载自己的数据集实例详解
Mar 18 Python
如何使用Python处理HDF格式数据及可视化问题
Jun 24 Python
Python request post上传文件常见要点
Nov 20 Python
Python内置类型集合set和frozenset的使用详解
Apr 26 Python
PyTorch基本数据类型(一)
May 22 #Python
PyTorch搭建多项式回归模型(三)
May 22 #Python
pytorch使用Variable实现线性回归
May 21 #Python
Python面向对象进阶学习
May 21 #Python
谈一谈基于python的面向对象编程基础
May 21 #Python
python字符串和常用数据结构知识总结
May 21 #Python
Opencv实现抠图背景图替换功能
May 21 #Python
You might like
海贼王动画变成“真人”后,凯多神还原,雷利太帅了!
2020/04/09 日漫
php.ini 中文版
2006/10/28 PHP
PHP跨时区(UTC时间)应用解决方案
2013/01/11 PHP
laravel 5 实现模板主题功能
2015/03/02 PHP
使用laravel指定日志文件记录任意日志
2019/10/17 PHP
php设计模式之原型模式分析【星际争霸游戏案例】
2020/03/23 PHP
javascript深入理解js闭包
2010/07/03 Javascript
JavaScript中最简洁的编码html字符串的方法
2014/10/11 Javascript
JavaScript常用脚本汇总(三)
2015/03/04 Javascript
javascript字符串循环匹配实例分析
2015/07/17 Javascript
原生JS实现图片懒加载(lazyload)实例
2017/06/13 Javascript
AugularJS从入门到实践(必看篇)
2017/07/10 Javascript
Vue数组更新及过滤排序功能
2017/08/10 Javascript
Angular实现搜索框及价格上下限功能
2018/01/19 Javascript
微信小程序实现自定义加载图标功能
2018/07/19 Javascript
vue图片上传本地预览组件使用详解
2019/02/20 Javascript
layui表格数据复选框回显设置方法
2019/09/13 Javascript
JavaScript中0、空字符串、'0'是true还是false的知识点分享
2019/09/16 Javascript
详解javascript脚本何时会被执行
2021/02/05 Javascript
[54:15]DOTA2-DPC中国联赛 正赛 DLG vs Dragon BO3 第二场2月1日
2021/03/11 DOTA
python+opencv轮廓检测代码解析
2018/01/05 Python
python 生成图形验证码的方法示例
2018/11/11 Python
python爬取指定微信公众号文章
2018/12/20 Python
对python 自定义协议的方法详解
2019/02/13 Python
python 如何将数据写入本地txt文本文件的实现方法
2019/09/11 Python
Win10+GPU版Pytorch1.1安装的安装步骤
2019/09/27 Python
django创建超级用户时指定添加其它字段方式
2020/05/14 Python
python按照list中字典的某key去重的示例代码
2020/10/13 Python
牧马人澳大利亚官网:Wrangler澳大利亚
2019/10/08 全球购物
小学国旗下的演讲稿
2014/08/28 职场文书
2016情人节宣传语
2015/07/14 职场文书
回复函范文
2015/07/14 职场文书
2019安全宣传标语大全
2019/08/14 职场文书
互联网创业商业模式以及赚钱法则有哪些?
2019/10/12 职场文书
SpringCloud Alibaba项目实战之nacos-server服务搭建过程
2021/06/21 Java/Android
Android学习之BottomSheetDialog组件的使用
2022/06/21 Java/Android