pytorch实现线性回归以及多元回归


Posted in Python onApril 11, 2021

本文实例为大家分享了pytorch实现线性回归以及多元回归的具体代码,供大家参考,具体内容如下

最近在学习pytorch,现在把学习的代码放在这里,下面是github链接

直接附上github代码

# 实现一个线性回归
# 所有的层结构和损失函数都来自于 torch.nn
# torch.optim 是一个实现各种优化算法的包,调用的时候必须是需要优化的参数传入,这些参数都必须是Variable
 
x_train = np.array([[3.3],[4.4],[5.5],[6.71],[6.93],[4.168],[9.779],[6.182],[7.59],[2.167],[7.042],[10.791],[5.313],[7.997],[3.1]],dtype=np.float32)
y_train = np.array([[1.7],[2.76],[2.09],[3.19],[1.694],[1.573],[3.366],[2.596],[2.53],[1.221],[2.827],[3.465],[1.65],[2.904],[1.3]],dtype=np.float32)
 
# 首先我们需要将array转化成tensor,因为pytorch处理的单元是Tensor
 
x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)
 
 
# def a simple network
 
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression,self).__init__()
        self.linear = nn.Linear(1, 1)  # input and output is 2_dimension
    def forward(self, x):
        out = self.linear(x)
        return out
 
 
if torch.cuda.is_available():
    model = LinearRegression().cuda()
    #model = model.cuda()
else:
    model = LinearRegression()
    #model = model.cuda()
 
# 定义loss function 和 optimize func
criterion = nn.MSELoss()   # 均方误差作为优化函数
optimizer = torch.optim.SGD(model.parameters(),lr=1e-3)
num_epochs = 30000
for epoch in range(num_epochs):
    if torch.cuda.is_available():
        inputs = Variable(x_train).cuda()
        outputs = Variable(y_train).cuda()
    else:
        inputs = Variable(x_train)
        outputs = Variable(y_train)
 
    # forward
    out = model(inputs)
    loss = criterion(out,outputs)
 
    # backword
    optimizer.zero_grad()  # 每次做反向传播之前都要进行归零梯度。不然梯度会累加在一起,造成不收敛的结果
    loss.backward()
    optimizer.step()
 
    if (epoch +1)%20==0:
        print('Epoch[{}/{}], loss: {:.6f}'.format(epoch+1,num_epochs,loss.data))
 
 
model.eval()  # 将模型变成测试模式
predict = model(Variable(x_train).cuda())
predict = predict.data.cpu().numpy()
plt.plot(x_train.numpy(),y_train.numpy(),'ro',label = 'original data')
plt.plot(x_train.numpy(),predict,label = 'Fitting line')
plt.show()

结果如图所示:

pytorch实现线性回归以及多元回归

多元回归:

# _*_encoding=utf-8_*_
# pytorch 里面最基本的操作对象是Tensor,pytorch 的tensor可以和numpy的ndarray相互转化。
# 实现一个线性回归
# 所有的层结构和损失函数都来自于 torch.nn
# torch.optim 是一个实现各种优化算法的包,调用的时候必须是需要优化的参数传入,这些参数都必须是Variable
 
 
# 实现 y = b + w1 *x + w2 *x**2 +w3*x**3
import os
os.environ['CUDA_DEVICE_ORDER']="PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES']='0'
import torch
import numpy as np
from torch.autograd import Variable
import matplotlib.pyplot as plt
from torch import nn
 
 
# pre_processing
def make_feature(x):
    x = x.unsqueeze(1)   # unsquenze 是为了添加维度1的,0表示第一维度,1表示第二维度,将tensor大小由3变为(3,1)
    return torch.cat([x ** i for i in range(1, 4)], 1)
 
# 定义好真实的数据
 
 
def f(x):
    W_output = torch.Tensor([0.5, 3, 2.4]).unsqueeze(1)
    b_output = torch.Tensor([0.9])
    return x.mm(W_output)+b_output[0]  # 外积,矩阵乘法
 
 
# 批量处理数据
def get_batch(batch_size =32):
 
    random = torch.randn(batch_size)
    x = make_feature(random)
    y = f(x)
    if torch.cuda.is_available():
 
        return Variable(x).cuda(),Variable(y).cuda()
    else:
        return Variable(x),Variable(y)
 
 
 
# def model
class poly_model(nn.Module):
    def __init__(self):
        super(poly_model,self).__init__()
        self.poly = nn.Linear(3,1)
    def forward(self,input):
        output = self.poly(input)
        return output
 
if torch.cuda.is_available():
    print("sdf")
    model = poly_model().cuda()
else:
    model = poly_model()
 
 
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
 
epoch = 0
while True:
    batch_x, batch_y = get_batch()
    #print(batch_x)
    output = model(batch_x)
    loss = criterion(output,batch_y)
    print_loss = loss.data
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch = epoch +1
    if print_loss < 1e-3:
        print(print_loss)
        break
 
model.eval()
print("Epoch = {}".format(epoch))
 
batch_x, batch_y = get_batch()
predict = model(batch_x)
a = predict - batch_y
y = torch.sum(a)
print('y = ',y)
predict = predict.data.cpu().numpy()
plt.plot(batch_x.cpu().numpy(),batch_y.cpu().numpy(),'ro',label = 'Original data')
plt.plot(batch_x.cpu().numpy(),predict,'b', ls='--',label = 'Fitting line')
plt.show()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
如何将python中的List转化成dictionary
Aug 15 Python
Python实现的递归神经网络简单示例
Aug 11 Python
Python升级导致yum、pip报错的解决方法
Sep 06 Python
python3.x实现base64加密和解密
Mar 28 Python
Python使用sklearn库实现的各种分类算法简单应用小结
Jul 04 Python
Python目录和文件处理总结详解
Sep 02 Python
使用python 计算百分位数实现数据分箱代码
Mar 03 Python
Django User 模块之 AbstractUser 扩展详解
Mar 11 Python
Python键鼠操作自动化库PyAutoGUI简介(小结)
May 17 Python
python开发一款翻译工具
Oct 10 Python
解决Pycharm 运行后没有输出的问题
Feb 05 Python
python利用while求100内的整数和方式
Nov 07 Python
python如何获取网络数据
Apr 11 #Python
Pytorch 使用tensor特定条件判断索引
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
Python3 使用pip安装git并获取Yahoo金融数据的操作
Apr 08 #Python
You might like
玩家交还《星际争霸》原始码光盘 暴雪报以厚礼
2017/05/05 星际争霸
php+ajax实现的点击浏览量加1
2015/04/16 PHP
Laravel 5使用Laravel Excel实现Excel/CSV文件导入导出的功能详解
2017/10/11 PHP
js的闭包的一个示例说明
2008/11/18 Javascript
利用jQuery的$.event.fix函数统一浏览器event事件处理
2009/12/21 Javascript
jquery parent和parents的区别分析
2013/10/02 Javascript
javascript日期验证之输入日期大于等于当前日期
2015/12/13 Javascript
Node.js本地文件操作之文件拷贝与目录遍历的方法
2016/02/16 Javascript
Bootstrap选项卡动态切换效果
2016/11/28 Javascript
Vue.js第四天学习笔记(组件)
2016/12/02 Javascript
js读取json文件片段中的数据实例
2017/03/09 Javascript
Node.js和Express简单入门介绍
2017/03/24 Javascript
微信小程序五星评分效果实现代码
2017/04/06 Javascript
import与export在node.js中的使用详解
2017/09/28 Javascript
js数据类型检测总结
2018/08/05 Javascript
微信小程序中使用自定义图标(阿里icon)的方法
2018/08/20 Javascript
ES6知识点整理之函数数组参数的默认值及其解构应用示例
2019/04/17 Javascript
vue 实现用户登录方式的切换功能
2020/04/14 Javascript
详解Vue2的diff算法
2021/01/06 Vue.js
[02:44]重置世界,颠覆未来——DOTA2 7.23版本震撼上线
2019/12/01 DOTA
python中global用法实例分析
2015/04/30 Python
在Python中操作字典之clear()方法的使用
2015/05/21 Python
在Python的Flask框架中构建Web表单的教程
2016/06/04 Python
python实现类之间的方法互相调用
2018/04/29 Python
pandas读取csv文件,分隔符参数sep的实例
2018/12/12 Python
python实现复制大量文件功能
2019/08/31 Python
Python hashlib加密模块常用方法解析
2019/12/18 Python
Python使用正则表达式实现爬虫数据抽取
2020/08/17 Python
俄罗斯茶和咖啡网上商店:Tea.ru
2021/01/26 全球购物
高级人员简历的自我评价分享
2013/11/03 职场文书
2014年大班保育员工作总结
2014/12/02 职场文书
使用python求解迷宫问题的三种实现方法
2022/03/17 Python
Elasticsearch 基本查询和组合查询
2022/04/19 Python
Mysql 数据库中的 redo log 和 binlog 写入策略
2022/04/26 MySQL
html中相对位置与绝对位置的具体使用
2022/05/15 HTML / CSS
分享几个实用的CSS代码块
2022/06/10 HTML / CSS