pytorch实现线性拟合方式


Posted in Python onJanuary 15, 2020

一维线性拟合

数据为y=4x+5加上噪音

结果:

pytorch实现线性拟合方式

import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import pyplot as plt
from torch.autograd import Variable
import torch
from torch import nn
 
X = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
Y = 4*X + 5 + torch.rand(X.size())
 
class LinearRegression(nn.Module):
 def __init__(self):
  super(LinearRegression, self).__init__()
  self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
 def forward(self, X):
  out = self.linear(X)
  return out
 
model = LinearRegression()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
 
num_epochs = 1000
for epoch in range(num_epochs):
 inputs = Variable(X)
 target = Variable(Y)
 # 向前传播
 out = model(inputs)
 loss = criterion(out, target)
 
 # 向后传播
 optimizer.zero_grad() # 注意每次迭代都需要清零
 loss.backward()
 optimizer.step()
 
 if (epoch + 1) % 20 == 0:
  print('Epoch[{}/{}], loss:{:.6f}'.format(epoch + 1, num_epochs, loss.item()))
model.eval()
predict = model(Variable(X))
predict = predict.data.numpy()
plt.plot(X.numpy(), Y.numpy(), 'ro', label='Original Data')
plt.plot(X.numpy(), predict, label='Fitting Line')
plt.show()

多维:

from itertools import count
import torch
import torch.autograd
import torch.nn.functional as F
 
POLY_DEGREE = 3
def make_features(x):
 """Builds features i.e. a matrix with columns [x, x^2, x^3]."""
 x = x.unsqueeze(1)
 return torch.cat([x ** i for i in range(1, POLY_DEGREE+1)], 1)
 
 
W_target = torch.randn(POLY_DEGREE, 1)
b_target = torch.randn(1)
 
 
def f(x):
 return x.mm(W_target) + b_target.item()
def get_batch(batch_size=32):
 random = torch.randn(batch_size)
 x = make_features(random)
 y = f(x)
 return x, y
# Define model
fc = torch.nn.Linear(W_target.size(0), 1)
batch_x, batch_y = get_batch()
print(batch_x,batch_y)
for batch_idx in count(1):
 # Get data
 
 
 # Reset gradients
 fc.zero_grad()
 
 # Forward pass
 output = F.smooth_l1_loss(fc(batch_x), batch_y)
 loss = output.item()
 
 # Backward pass
 output.backward()
 
 # Apply gradients
 for param in fc.parameters():
  param.data.add_(-0.1 * param.grad.data)
 
 # Stop criterion
 if loss < 1e-3:
  break
 
 
def poly_desc(W, b):
 """Creates a string description of a polynomial."""
 result = 'y = '
 for i, w in enumerate(W):
  result += '{:+.2f} x^{} '.format(w, len(W) - i)
 result += '{:+.2f}'.format(b[0])
 return result
 
 
print('Loss: {:.6f} after {} batches'.format(loss, batch_idx))
print('==> Learned function:\t' + poly_desc(fc.weight.view(-1), fc.bias))
print('==> Actual function:\t' + poly_desc(W_target.view(-1), b_target))

以上这篇pytorch实现线性拟合方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python subprocess模块学习总结
Mar 13 Python
python实现识别相似图片小结
Feb 22 Python
Python实现ssh批量登录并执行命令
Oct 25 Python
python导出chrome书签到markdown文件的实例代码
Dec 27 Python
Python面向对象之静态属性、类方法与静态方法分析
Aug 24 Python
python词云库wordCloud使用方法详解(解决中文乱码)
Feb 17 Python
利用Python制作动态排名图的实现代码
Apr 09 Python
keras自定义回调函数查看训练的loss和accuracy方式
May 23 Python
Python中无限循环需要什么条件
May 27 Python
Python flask框架如何显示图像到web页面
Jun 03 Python
Python接收手机短信的代码整理
Aug 02 Python
Python实现异步IO的示例
Nov 05 Python
Python 支持向量机分类器的实现
Jan 15 #Python
pytorch-神经网络拟合曲线实例
Jan 15 #Python
Pytorch中的VGG实现修改最后一层FC
Jan 15 #Python
详解Python3 中的字符串格式化语法
Jan 15 #Python
用pytorch的nn.Module构造简单全链接层实例
Jan 14 #Python
pytorch三层全连接层实现手写字母识别方式
Jan 14 #Python
Python实现bilibili时间长度查询的示例代码
Jan 14 #Python
You might like
PHP编程中字符串处理的5个技巧小结
2007/11/13 PHP
php中用socket模拟http中post或者get提交数据的示例代码
2013/08/08 PHP
thinkPHP的表达式查询用法详解
2016/09/14 PHP
PHP实现的常规正则验证helper公共类完整实例
2017/04/27 PHP
tp5实现微信小程序多图片上传到服务器功能
2018/07/16 PHP
Laravel框架查询构造器 CURD操作示例
2019/09/04 PHP
jquery json 实例代码
2010/12/02 Javascript
jQuery学习笔记 操作jQuery对象 属性处理
2012/09/19 Javascript
js判断滚动条是否已到页面最底部或顶部实例
2014/11/20 Javascript
jQuery中position()方法用法实例
2015/01/16 Javascript
javascript关于运动的各种问题经典总结
2015/04/27 Javascript
jQuery ajax分页插件实例代码
2016/01/27 Javascript
js实现打地鼠小游戏
2017/02/13 Javascript
详解微信小程序Radio选中样式切换
2017/07/06 Javascript
BootstrapTable加载按钮功能实例代码详解
2017/09/22 Javascript
layui使用表格渲染获取行数据的例子
2019/09/13 Javascript
IDEA安装vue插件图文详解
2019/09/26 Javascript
如何基于JS截获动态代码
2019/12/25 Javascript
解决antd Form 表单校验方法无响应的问题
2020/10/27 Javascript
python3访问sina首页中文的处理方法
2014/02/24 Python
Python使用cookielib模块操作cookie的实例教程
2016/07/12 Python
Python的标准模块包json详解
2017/03/13 Python
python与sqlite3实现解密chrome cookie实例代码
2018/01/20 Python
Python中enumerate()函数编写更Pythonic的循环
2018/03/06 Python
python微元法计算函数曲线长度的方法
2018/11/08 Python
windows下安装Python虚拟环境virtualenvwrapper-win
2019/06/14 Python
Django 重写用户模型的实现
2019/07/29 Python
PyQt5 closeEvent关闭事件退出提示框原理解析
2020/01/08 Python
Python中使用filter过滤列表的一个小技巧分享
2020/05/02 Python
python安装mysql的依赖包mysql-python操作
2021/01/01 Python
Koral官方网站:女性时尚运动服
2019/04/10 全球购物
毕业生幼师求职自荐信
2013/10/01 职场文书
中文专业毕业生自荐信
2014/05/24 职场文书
钱学森观后感
2015/06/04 职场文书
惊天动地观后感
2015/06/10 职场文书
欧也妮葛朗台读书笔记
2015/06/30 职场文书