pytorch实现线性回归以及多元回归


Posted in Python onApril 11, 2021

本文实例为大家分享了pytorch实现线性回归以及多元回归的具体代码,供大家参考,具体内容如下

最近在学习pytorch,现在把学习的代码放在这里,下面是github链接

直接附上github代码

# 实现一个线性回归
# 所有的层结构和损失函数都来自于 torch.nn
# torch.optim 是一个实现各种优化算法的包,调用的时候必须是需要优化的参数传入,这些参数都必须是Variable
 
x_train = np.array([[3.3],[4.4],[5.5],[6.71],[6.93],[4.168],[9.779],[6.182],[7.59],[2.167],[7.042],[10.791],[5.313],[7.997],[3.1]],dtype=np.float32)
y_train = np.array([[1.7],[2.76],[2.09],[3.19],[1.694],[1.573],[3.366],[2.596],[2.53],[1.221],[2.827],[3.465],[1.65],[2.904],[1.3]],dtype=np.float32)
 
# 首先我们需要将array转化成tensor,因为pytorch处理的单元是Tensor
 
x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)
 
 
# def a simple network
 
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression,self).__init__()
        self.linear = nn.Linear(1, 1)  # input and output is 2_dimension
    def forward(self, x):
        out = self.linear(x)
        return out
 
 
if torch.cuda.is_available():
    model = LinearRegression().cuda()
    #model = model.cuda()
else:
    model = LinearRegression()
    #model = model.cuda()
 
# 定义loss function 和 optimize func
criterion = nn.MSELoss()   # 均方误差作为优化函数
optimizer = torch.optim.SGD(model.parameters(),lr=1e-3)
num_epochs = 30000
for epoch in range(num_epochs):
    if torch.cuda.is_available():
        inputs = Variable(x_train).cuda()
        outputs = Variable(y_train).cuda()
    else:
        inputs = Variable(x_train)
        outputs = Variable(y_train)
 
    # forward
    out = model(inputs)
    loss = criterion(out,outputs)
 
    # backword
    optimizer.zero_grad()  # 每次做反向传播之前都要进行归零梯度。不然梯度会累加在一起,造成不收敛的结果
    loss.backward()
    optimizer.step()
 
    if (epoch +1)%20==0:
        print('Epoch[{}/{}], loss: {:.6f}'.format(epoch+1,num_epochs,loss.data))
 
 
model.eval()  # 将模型变成测试模式
predict = model(Variable(x_train).cuda())
predict = predict.data.cpu().numpy()
plt.plot(x_train.numpy(),y_train.numpy(),'ro',label = 'original data')
plt.plot(x_train.numpy(),predict,label = 'Fitting line')
plt.show()

结果如图所示:

pytorch实现线性回归以及多元回归

多元回归:

# _*_encoding=utf-8_*_
# pytorch 里面最基本的操作对象是Tensor,pytorch 的tensor可以和numpy的ndarray相互转化。
# 实现一个线性回归
# 所有的层结构和损失函数都来自于 torch.nn
# torch.optim 是一个实现各种优化算法的包,调用的时候必须是需要优化的参数传入,这些参数都必须是Variable
 
 
# 实现 y = b + w1 *x + w2 *x**2 +w3*x**3
import os
os.environ['CUDA_DEVICE_ORDER']="PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES']='0'
import torch
import numpy as np
from torch.autograd import Variable
import matplotlib.pyplot as plt
from torch import nn
 
 
# pre_processing
def make_feature(x):
    x = x.unsqueeze(1)   # unsquenze 是为了添加维度1的,0表示第一维度,1表示第二维度,将tensor大小由3变为(3,1)
    return torch.cat([x ** i for i in range(1, 4)], 1)
 
# 定义好真实的数据
 
 
def f(x):
    W_output = torch.Tensor([0.5, 3, 2.4]).unsqueeze(1)
    b_output = torch.Tensor([0.9])
    return x.mm(W_output)+b_output[0]  # 外积,矩阵乘法
 
 
# 批量处理数据
def get_batch(batch_size =32):
 
    random = torch.randn(batch_size)
    x = make_feature(random)
    y = f(x)
    if torch.cuda.is_available():
 
        return Variable(x).cuda(),Variable(y).cuda()
    else:
        return Variable(x),Variable(y)
 
 
 
# def model
class poly_model(nn.Module):
    def __init__(self):
        super(poly_model,self).__init__()
        self.poly = nn.Linear(3,1)
    def forward(self,input):
        output = self.poly(input)
        return output
 
if torch.cuda.is_available():
    print("sdf")
    model = poly_model().cuda()
else:
    model = poly_model()
 
 
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
 
epoch = 0
while True:
    batch_x, batch_y = get_batch()
    #print(batch_x)
    output = model(batch_x)
    loss = criterion(output,batch_y)
    print_loss = loss.data
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch = epoch +1
    if print_loss < 1e-3:
        print(print_loss)
        break
 
model.eval()
print("Epoch = {}".format(epoch))
 
batch_x, batch_y = get_batch()
predict = model(batch_x)
a = predict - batch_y
y = torch.sum(a)
print('y = ',y)
predict = predict.data.cpu().numpy()
plt.plot(batch_x.cpu().numpy(),batch_y.cpu().numpy(),'ro',label = 'Original data')
plt.plot(batch_x.cpu().numpy(),predict,'b', ls='--',label = 'Fitting line')
plt.show()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用python写asp详细讲解
Dec 16 Python
python使用socket远程连接错误处理方法
Apr 29 Python
Jupyter notebook远程访问服务器的方法
May 24 Python
python实现泊松图像融合
Jul 26 Python
python 与服务器的共享文件夹交互方法
Dec 27 Python
Python @property使用方法解析
Sep 17 Python
详解Python可视化神器Yellowbrick使用
Nov 11 Python
Python3操作MongoDB增册改查等方法详解
Feb 10 Python
python为什么会环境变量设置不成功
Jun 23 Python
Python中的None与 NULL(即空字符)的区别详解
Sep 24 Python
Django利用elasticsearch(搜索引擎)实现搜索功能
Nov 26 Python
Python中正则表达式对单个字符,多个字符和匹配边界等使用
Jan 27 Python
python如何获取网络数据
Apr 11 #Python
Pytorch 使用tensor特定条件判断索引
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
Python3 使用pip安装git并获取Yahoo金融数据的操作
Apr 08 #Python
You might like
8个出色的WordPress SEO插件收集
2011/02/26 PHP
PHP实现的sqlite数据库连接类
2014/12/12 PHP
Windows2003下php5.4安装配置教程(IIS)
2016/06/30 PHP
php ZipArchive实现多文件打包下载实例
2019/10/31 PHP
用javascript实现点击链接弹出&quot;图片另存为&quot;而不是直接打开
2007/08/15 Javascript
深入理解jQuery中live与bind方法的区别
2013/12/18 Javascript
javascript判断office版本示例
2014/04/11 Javascript
jQuery获取动态生成的元素示例
2014/06/15 Javascript
IE8下Jquery获取select选中的值post到后台报错问题
2014/07/02 Javascript
JavaScript制作淘宝星级评分效果的思路
2020/06/23 Javascript
Bootstrap CSS组件之大屏幕展播
2016/12/17 Javascript
js实现选项卡内容切换以及折叠和展开效果【推荐】
2017/01/08 Javascript
nuxt+axios解决前后端分离SSR的示例代码
2017/10/24 Javascript
vue主动刷新页面及列表数据删除后的刷新实例
2018/09/16 Javascript
基于Vue 撸一个指令实现拖拽功能
2019/10/09 Javascript
jQuery AJAX应用实例总结
2020/05/19 jQuery
如何使用jQuery操作Cookies方法解析
2020/09/08 jQuery
Python遍历目录的4种方法实例介绍
2015/04/13 Python
Python Pandas找到缺失值的位置方法
2018/04/12 Python
Python中.join()和os.path.join()两个函数的用法详解
2018/06/11 Python
Python机器学习k-近邻算法(K Nearest Neighbor)实例详解
2018/06/25 Python
python按键按住不放持续响应的实例代码
2019/07/17 Python
python支付宝支付示例详解
2019/08/22 Python
美国网上鞋城:Shoeline.com
2016/11/17 全球购物
J.Crew官网:美国知名休闲服装品牌
2017/05/19 全球购物
英国亚马逊官方网站:Amazon.co.uk
2019/08/09 全球购物
联想阿根廷官方网站:Lenovo Argentina
2019/10/14 全球购物
护士自我鉴定
2013/10/23 职场文书
国贸专业的职业规划范文
2014/01/23 职场文书
高中生职业规划范文
2014/03/09 职场文书
2014年幼儿园保育工作总结
2014/12/02 职场文书
2014年公路养护工作总结
2014/12/04 职场文书
刑事上诉状(量刑过重)
2015/05/23 职场文书
SQL IDENTITY_INSERT作用案例详解
2021/08/23 MySQL
JS高级程序设计之class继承重点详解
2022/07/07 Javascript
SQL中去除重复数据的几种方法汇总(窗口函数对数据去重)
2023/05/08 MySQL