pytorch使用Variable实现线性回归


Posted in Python onMay 21, 2019

本文实例为大家分享了pytorch使用Variable实现线性回归的具体代码,供大家参考,具体内容如下

一、手动计算梯度实现线性回归

#导入相关包
import torch as t
import matplotlib.pyplot as plt
 
#构造数据
def get_fake_data(batch_size = 8):
 #设置随机种子数,这样每次生成的随机数都是一样的
 t.manual_seed(10)
 #产生随机数据:y = 2*x+3,加上了一些噪声
 x = t.rand(batch_size,1) * 20
 #randn生成期望为0方差为1的正态分布随机数
 y = x * 2 + (1 + t.randn(batch_size,1)) * 3 
 return x,y
 
#查看生成数据的分布
x,y = get_fake_data()
plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
#线性回归
 
#随机初始化参数
w = t.rand(1,1)
b = t.zeros(1,1)
#学习率
lr = 0.001 
 
for i in range(10000):
 x,y = get_fake_data()
 
 #forward:计算loss
 y_pred = x.mm(w) + b.expand_as(y)
 
 #均方误差作为损失函数
 loss = 0.5 * (y_pred - y)**2 
 loss = loss.sum()
 
 #backward:手动计算梯度
 dloss = 1
 dy_pred = dloss * (y_pred - y)
 dw = x.t().mm(dy_pred)
 db = dy_pred.sum()
 
 #更新参数
 w.sub_(lr * dw)
 b.sub_(lr * db)
 
 if i%1000 == 0:
 #画图
 plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
 x1 = t.arange(0,20).float().view(-1,1)
 y1 = x1.mm(w) + b.expand_as(x1)
 plt.plot(x1.numpy(),y1.numpy()) #predicted
 plt.show()
 #plt.pause(0.5)
 print(w.squeeze(),b.squeeze())

pytorch使用Variable实现线性回归

显示的最后一张图如下所示:

pytorch使用Variable实现线性回归

二、自动梯度 计算梯度实现线性回归

#导入相关包
import torch as t
from torch.autograd import Variable as V
import matplotlib.pyplot as plt
 
#构造数据
def get_fake_data(batch_size=8):
 t.manual_seed(10) #设置随机数种子
 x = t.rand(batch_size,1) * 20
 y = 2 * x +(1 + t.randn(batch_size,1)) * 3
 return x,y
 
#查看产生的x,y的分布是什么样的
x,y = get_fake_data()
plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
 
#线性回归
 
#初始化随机参数
w = V(t.rand(1,1),requires_grad=True)
b = V(t.rand(1,1),requires_grad=True)
lr = 0.001
for i in range(8000):
 x,y = get_fake_data()
 x,y = V(x),V(y)
 y_pred = x * w + b
 loss = 0.5 * (y_pred-y)**2
 loss = loss.sum()
 
 #自动计算梯度
 loss.backward()
 #更新参数
 w.data.sub_(lr * w.grad.data)
 b.data.sub_(lr * b.grad.data)
 
 #梯度清零,不清零梯度会累加的
 w.grad.data.zero_()
 b.grad.data.zero_()
 
 if i%1000==0:
 #predicted
 x = t.arange(0,20).float().view(-1,1)
 y = x.mm(w.data) + b.data.expand_as(x)
 plt.plot(x.numpy(),y.numpy())
 
 #true data
 x2,y2 = get_fake_data()
 plt.scatter(x2.numpy(),y2.numpy())
 plt.show()
print(w.data[0],b.data[0])

pytorch使用Variable实现线性回归

显示的最后一张图如下所示:

pytorch使用Variable实现线性回归

用autograd实现的线性回归最大的不同点就在于利用autograd不需要手动计算梯度,可以自动微分。这一点不单是在深度在学习中,在许多机器学习的问题中都很有用。另外,需要注意的是每次反向传播之前要记得先把梯度清零,因为autograd求得的梯度是自动累加的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python转换HTML到Text纯文本的方法
Jan 15 Python
Python最长公共子串算法实例
Mar 07 Python
python使用chardet判断字符串编码的方法
Mar 13 Python
用PyQt进行Python图形界面的程序的开发的入门指引
Apr 14 Python
Python实现将罗马数字转换成普通阿拉伯数字的方法
Apr 19 Python
PyQt实现界面翻转切换效果
Apr 20 Python
Python3正则匹配re.split,re.finditer及re.findall函数用法详解
Jun 11 Python
python+opencv像素的加减和加权操作的实现
Jul 14 Python
python网络爬虫 Scrapy中selenium用法详解
Sep 28 Python
Python中断多重循环的思路总结
Oct 04 Python
python 循环数据赋值实例
Dec 02 Python
Python 实现将某一列设置为str类型
Jul 14 Python
Python面向对象进阶学习
May 21 #Python
谈一谈基于python的面向对象编程基础
May 21 #Python
python字符串和常用数据结构知识总结
May 21 #Python
Opencv实现抠图背景图替换功能
May 21 #Python
python多进程读图提取特征存npy
May 21 #Python
Python中使用pypdf2合并、分割、加密pdf文件的代码详解
May 21 #Python
python+selenium实现简历自动刷新的示例代码
May 20 #Python
You might like
PHP中break及continue两个流程控制指令区别分析
2011/04/18 PHP
如何使用“PHP” 彩蛋进行敏感信息获取
2013/08/07 PHP
PHP使用PHPMailer发送邮件的简单使用方法
2013/11/12 PHP
PHP自定义函数获取汉字首字母的方法
2016/12/01 PHP
Json_decode 解析json字符串为NULL的解决方法(必看)
2017/02/17 PHP
PHP反射实际应用示例
2019/04/03 PHP
Add Formatted Data to a Spreadsheet
2007/06/12 Javascript
JavaScript 基础问答三
2008/12/03 Javascript
Js注册协议倒计时的小例子
2013/06/24 Javascript
js实现广告漂浮效果的小例子
2013/07/02 Javascript
JavaScript编程中的Promise使用大全
2015/07/28 Javascript
jQuery插件ajaxfileupload.js实现上传文件
2020/10/23 Javascript
学习Node.js模块机制
2016/10/17 Javascript
javascript基本数据类型及类型检测常用方法小结
2016/12/14 Javascript
js 实现省市区三级联动菜单效果
2017/02/20 Javascript
原生JS实现DOM加载完成马上执行JS代码的方法
2018/09/07 Javascript
Vue项目安装插件并保存
2019/01/28 Javascript
jQuery实现弹出层效果
2019/12/10 jQuery
在vue中实现echarts随窗体变化
2020/07/27 Javascript
vue使用canvas实现移动端手写签名
2020/09/22 Javascript
CentOS安装pillow报错的解决方法
2016/01/27 Python
python3+requests接口自动化session操作方法
2018/10/13 Python
使用pycharm设置控制台不换行的操作方法
2019/01/19 Python
python实现合并两个排序的链表
2019/03/03 Python
Python 给屏幕打印信息加上颜色的实现方法
2019/04/24 Python
python调用支付宝支付接口流程
2019/08/15 Python
python3用urllib抓取贴吧邮箱和QQ实例
2020/03/10 Python
基于PyQT实现区分左键双击和单击
2020/05/19 Python
项目考察欢迎辞
2014/01/17 职场文书
2014幼儿园班主任工作总结
2014/12/04 职场文书
交通安全温馨提示语
2015/07/14 职场文书
车间安全生产管理制度
2015/08/06 职场文书
2015年文秘个人工作总结
2015/10/14 职场文书
幼儿园开学家长寄语(2016秋季)
2015/12/03 职场文书
60句有关成长的名言
2019/09/04 职场文书
Vue项目打包、合并及压缩优化网页响应速度
2021/07/07 Vue.js