pytorch使用Variable实现线性回归


Posted in Python onMay 21, 2019

本文实例为大家分享了pytorch使用Variable实现线性回归的具体代码,供大家参考,具体内容如下

一、手动计算梯度实现线性回归

#导入相关包
import torch as t
import matplotlib.pyplot as plt
 
#构造数据
def get_fake_data(batch_size = 8):
 #设置随机种子数,这样每次生成的随机数都是一样的
 t.manual_seed(10)
 #产生随机数据:y = 2*x+3,加上了一些噪声
 x = t.rand(batch_size,1) * 20
 #randn生成期望为0方差为1的正态分布随机数
 y = x * 2 + (1 + t.randn(batch_size,1)) * 3 
 return x,y
 
#查看生成数据的分布
x,y = get_fake_data()
plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
#线性回归
 
#随机初始化参数
w = t.rand(1,1)
b = t.zeros(1,1)
#学习率
lr = 0.001 
 
for i in range(10000):
 x,y = get_fake_data()
 
 #forward:计算loss
 y_pred = x.mm(w) + b.expand_as(y)
 
 #均方误差作为损失函数
 loss = 0.5 * (y_pred - y)**2 
 loss = loss.sum()
 
 #backward:手动计算梯度
 dloss = 1
 dy_pred = dloss * (y_pred - y)
 dw = x.t().mm(dy_pred)
 db = dy_pred.sum()
 
 #更新参数
 w.sub_(lr * dw)
 b.sub_(lr * db)
 
 if i%1000 == 0:
 #画图
 plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
 x1 = t.arange(0,20).float().view(-1,1)
 y1 = x1.mm(w) + b.expand_as(x1)
 plt.plot(x1.numpy(),y1.numpy()) #predicted
 plt.show()
 #plt.pause(0.5)
 print(w.squeeze(),b.squeeze())

pytorch使用Variable实现线性回归

显示的最后一张图如下所示:

pytorch使用Variable实现线性回归

二、自动梯度 计算梯度实现线性回归

#导入相关包
import torch as t
from torch.autograd import Variable as V
import matplotlib.pyplot as plt
 
#构造数据
def get_fake_data(batch_size=8):
 t.manual_seed(10) #设置随机数种子
 x = t.rand(batch_size,1) * 20
 y = 2 * x +(1 + t.randn(batch_size,1)) * 3
 return x,y
 
#查看产生的x,y的分布是什么样的
x,y = get_fake_data()
plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
#线性回归
 
#初始化随机参数
w = V(t.rand(1,1),requires_grad=True)
b = V(t.rand(1,1),requires_grad=True)
lr = 0.001
for i in range(8000):
 x,y = get_fake_data()
 x,y = V(x),V(y)
 y_pred = x * w + b
 loss = 0.5 * (y_pred-y)**2
 loss = loss.sum()
 
 #自动计算梯度
 loss.backward()
 #更新参数
 w.data.sub_(lr * w.grad.data)
 b.data.sub_(lr * b.grad.data)
 
 #梯度清零,不清零梯度会累加的
 w.grad.data.zero_()
 b.grad.data.zero_()
 
 if i%1000==0:
 #predicted
 x = t.arange(0,20).float().view(-1,1)
 y = x.mm(w.data) + b.data.expand_as(x)
 plt.plot(x.numpy(),y.numpy())
 
 #true data
 x2,y2 = get_fake_data()
 plt.scatter(x2.numpy(),y2.numpy())
 plt.show()
print(w.data[0],b.data[0])

pytorch使用Variable实现线性回归

显示的最后一张图如下所示:

pytorch使用Variable实现线性回归

用autograd实现的线性回归最大的不同点就在于利用autograd不需要手动计算梯度,可以自动微分。这一点不单是在深度在学习中,在许多机器学习的问题中都很有用。另外,需要注意的是每次反向传播之前要记得先把梯度清零,因为autograd求得的梯度是自动累加的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中利用sqrt()方法进行平方根计算的教程
May 15 Python
python用pickle模块实现“增删改查”的简易功能
Jun 07 Python
Python实现XML文件解析的示例代码
Feb 05 Python
python numpy和list查询其中某个数的个数及定位方法
Jun 27 Python
解决sublime+python3无法输出中文的问题
Dec 12 Python
对python文件读写的缓冲行为详解
Feb 13 Python
分析运行中的 Python 进程详细解析
Jun 22 Python
Pytorch之view及view_as使用详解
Dec 31 Python
基于python实现语音录入识别代码实例
Jan 17 Python
Python运行DLL文件的方法
Jan 17 Python
python3爬虫中引用Queue的实例讲解
Nov 24 Python
详解Django自定义图片和文件上传路径(upload_to)的2种方式
Dec 01 Python
Python面向对象进阶学习
May 21 #Python
谈一谈基于python的面向对象编程基础
May 21 #Python
python字符串和常用数据结构知识总结
May 21 #Python
Opencv实现抠图背景图替换功能
May 21 #Python
python多进程读图提取特征存npy
May 21 #Python
Python中使用pypdf2合并、分割、加密pdf文件的代码详解
May 21 #Python
python+selenium实现简历自动刷新的示例代码
May 20 #Python
You might like
Youku 视频绝对地址获取的方法详解
2013/06/26 PHP
PHP使用Session遇到的一个Permission denied Notice解决办法
2014/07/30 PHP
PHP编程实现多维数组按照某个键值排序的方法小结【2种方法】
2017/04/27 PHP
在php的yii2框架中整合hbase库的方法
2018/09/20 PHP
JavaScript中的几个关键概念的理解-原型链的构建
2011/05/12 Javascript
仅IE不支持setTimeout/setInterval函数的第三个以上参数
2011/05/25 Javascript
js返回前一页刷新本页重载页面
2014/07/29 Javascript
JavaScript数组函数unshift、shift、pop、push使用实例
2014/08/27 Javascript
21个JavaScript事件(Events)属性汇总
2014/12/02 Javascript
Jquery 实现checkbox全选方法
2015/01/28 Javascript
jQuery支持动态参数将函数绑定到事件上的方法
2015/03/17 Javascript
jquery validate demo 基础
2015/10/29 Javascript
jquery实现全选、反选、获得所有选中的checkbox
2020/09/13 Javascript
全屏js头像上传插件源码高清版
2016/03/29 Javascript
jQuery实现背景弹性滚动的导航效果
2016/06/01 Javascript
nodejs处理图片的中间件node-images详解
2017/05/08 NodeJs
AngularJs导出数据到Excel的示例代码
2017/08/11 Javascript
vue滚动轴插件better-scroll使用详解
2017/10/17 Javascript
关于Angularjs中跨域设置白名单问题
2018/04/17 Javascript
localstorage实现带过期时间的缓存功能
2019/06/28 Javascript
使用python搭建Django应用程序步骤及版本冲突问题解决
2013/11/19 Python
跟老齐学Python之使用Python查询更新数据库
2014/11/25 Python
Python实现的最近最少使用算法
2015/07/10 Python
python3实现ftp服务功能(客户端)
2017/03/24 Python
Python实现的读写json文件功能示例
2018/06/05 Python
下载与当前Chrome对应的chromedriver.exe(用于python+selenium)
2020/01/14 Python
Python通过递归函数输出嵌套列表元素
2020/10/15 Python
如何用 Python 处理不平衡数据集
2021/01/04 Python
html5 canvas绘制网络字体的常用方法
2019/08/26 HTML / CSS
lululemon美国官网:瑜伽服+跑步装备
2018/11/16 全球购物
行政管理专业推荐信
2013/11/02 职场文书
2014年财务工作总结与计划
2014/12/08 职场文书
2015年教师节演讲稿范文
2015/03/19 职场文书
MySQL系列之三 基础篇
2021/07/02 MySQL
Go归并排序算法的实现方法
2022/04/06 Golang
python区块链持久化和命令行接口实现简版
2022/05/25 Python