使用pytorch实现线性回归


Posted in Python onApril 11, 2021

线性回归都是包括以下几个步骤:定义模型、选择损失函数、选择优化函数、 训练数据、测试

import torch
import matplotlib.pyplot as plt
# 构建数据集
x_data= torch.Tensor([[1.0],[2.0],[3.0],[4.0],[5.0],[6.0]])
y_data= torch.Tensor([[2.0],[4.0],[6.0],[8.0],[10.0],[12.0]])
#定义模型
class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear= torch.nn.Linear(1,1) #表示输入输出都只有一层,相当于前向传播中的函数模型,因为我们一般都不知道函数是什么形式的
 
    def forward(self, x):
        y_pred= self.linear(x)
        return y_pred
model= LinearModel()
# 使用均方误差作为损失函数
criterion= torch.nn.MSELoss(size_average= False)
#使用梯度下降作为优化SGD
# 从下面几种优化器的生成结果图像可以看出,SGD和ASGD效果最好,因为他们的图像收敛速度最快
optimizer= torch.optim.SGD(model.parameters(),lr=0.01)
# ASGD
# optimizer= torch.optim.ASGD(model.parameters(),lr=0.01)
# optimizer= torch.optim.Adagrad(model.parameters(), lr= 0.01)
# optimizer= torch.optim.RMSprop(model.parameters(), lr= 0.01)
# optimizer= torch.optim.Adamax(model.parameters(),lr= 0.01)
# 训练
epoch_list=[]
loss_list=[]
for epoch in range(100):
    y_pred= model(x_data)
    loss= criterion(y_pred, y_data)
    epoch_list.append(epoch)
    loss_list.append(loss.item())
    print(epoch, loss.item())
 
    optimizer.zero_grad() #梯度归零
    loss.backward()  #反向传播
    optimizer.step() #更新参数
 
print("w= ", model.linear.weight.item())
print("b= ",model.linear.bias.item())
 
x_test= torch.Tensor([[7.0]])
y_test= model(x_test)
print("y_pred= ",y_test.data)
 
plt.plot(epoch_list, loss_list)
plt.xlabel("epoch")
plt.ylabel("loss_val")
plt.show()

使用SGD优化器图像:                                                      

使用pytorch实现线性回归

使用ASGD优化器图像:

使用pytorch实现线性回归

使用Adagrad优化器图像:                                                 

使用pytorch实现线性回归

使用Adamax优化器图像:

使用pytorch实现线性回归

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python的Django框架中获取单个对象数据的简单方法
Jul 17 Python
利用python批量给云主机配置安全组的方法教程
Jun 21 Python
初学python的操作难点总结(新手必看篇)
Aug 03 Python
Python3导入自定义模块的三种方法详解
Apr 13 Python
Python button选取本地图片并显示的实例
Jun 13 Python
django+tornado实现实时查看远程日志的方法
Aug 12 Python
pytest中文文档之编写断言
Sep 12 Python
Python如何优雅获取本机IP方法
Nov 10 Python
Django通用类视图实现忘记密码重置密码功能示例
Dec 17 Python
Django框架之中间件MiddleWare的实现
Dec 30 Python
Python运行DLL文件的方法
Jan 17 Python
Pytorch 使用不同版本的cuda的方法步骤
Apr 02 Python
pytorch实现线性回归以及多元回归
python如何获取网络数据
Apr 11 #Python
Pytorch 使用tensor特定条件判断索引
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
You might like
smarty基础之拼接字符串的详解
2013/06/18 PHP
php通过strpos查找字符串出现位置的方法
2015/03/17 PHP
PHP设计模式之适配器模式(Adapter)原理与用法详解
2019/12/12 PHP
jQuery的实现原理的模拟代码 -2 数据部分
2010/08/01 Javascript
javascript复制对象使用说明
2011/06/28 Javascript
jquery 模板的应用示例
2013/11/12 Javascript
《JavaScript DOM 编程艺术》读书笔记之JavaScript 简史
2015/01/09 Javascript
javascript通过元素id和name直接取得元素的方法
2015/04/28 Javascript
JQuery中clone方法复制节点
2015/05/18 Javascript
javascript实现图片延迟加载方法汇总(三种方法)
2015/08/27 Javascript
浅谈jQuery animate easing的具体使用方法(推荐)
2016/06/17 Javascript
浅谈JS中json数据的处理
2016/06/30 Javascript
微信小程序之数据绑定原理解析
2019/08/14 Javascript
微信小程序整个页面的自动适应布局的实现
2020/07/12 Javascript
JS跨浏览器解析XML应用过程详解
2020/10/16 Javascript
[47:42]完美世界DOTA2联赛PWL S2 GXR vs Ink 第一场 11.19
2020/11/20 DOTA
Python中list初始化方法示例
2016/09/18 Python
python用模块zlib压缩与解压字符串和文件的方法
2016/12/16 Python
Python简单定义与使用二叉树示例
2018/05/11 Python
Python如何发布程序的详细教程
2018/10/09 Python
python实现屏保程序(适用于背单词)
2019/07/30 Python
python单向链表的基本实现与使用方法【定义、遍历、添加、删除、查找等】
2019/10/24 Python
python 微信好友特征数据分析及可视化
2020/01/07 Python
Python基于smtplib协议实现发送邮件
2020/06/03 Python
Troy-Bilt官网:草坪割草机、吹雪机、分蘖机等
2019/02/19 全球购物
乌克兰网上珠宝商店:GoldSoveren
2020/03/31 全球购物
自荐信的两点禁忌
2013/10/30 职场文书
个人批评与自我批评发言稿
2014/09/28 职场文书
学校党委干部个人对照检查材料思想汇报
2014/10/09 职场文书
优秀教师单行材料
2014/12/16 职场文书
2015年劳动部工作总结
2015/05/23 职场文书
2016党员三严三实心得体会
2016/01/15 职场文书
创业计划书之游泳馆
2019/09/16 职场文书
详解MySQL主从复制及读写分离
2021/05/07 MySQL
一文搞懂MySQL索引页结构
2022/02/28 MySQL
MySQL深分页问题解决思路
2022/12/24 MySQL