使用pytorch实现线性回归


Posted in Python onApril 11, 2021

线性回归都是包括以下几个步骤:定义模型、选择损失函数、选择优化函数、 训练数据、测试

import torch
import matplotlib.pyplot as plt
# 构建数据集
x_data= torch.Tensor([[1.0],[2.0],[3.0],[4.0],[5.0],[6.0]])
y_data= torch.Tensor([[2.0],[4.0],[6.0],[8.0],[10.0],[12.0]])
#定义模型
class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear= torch.nn.Linear(1,1) #表示输入输出都只有一层,相当于前向传播中的函数模型,因为我们一般都不知道函数是什么形式的
 
    def forward(self, x):
        y_pred= self.linear(x)
        return y_pred
model= LinearModel()
# 使用均方误差作为损失函数
criterion= torch.nn.MSELoss(size_average= False)
#使用梯度下降作为优化SGD
# 从下面几种优化器的生成结果图像可以看出,SGD和ASGD效果最好,因为他们的图像收敛速度最快
optimizer= torch.optim.SGD(model.parameters(),lr=0.01)
# ASGD
# optimizer= torch.optim.ASGD(model.parameters(),lr=0.01)
# optimizer= torch.optim.Adagrad(model.parameters(), lr= 0.01)
# optimizer= torch.optim.RMSprop(model.parameters(), lr= 0.01)
# optimizer= torch.optim.Adamax(model.parameters(),lr= 0.01)
# 训练
epoch_list=[]
loss_list=[]
for epoch in range(100):
    y_pred= model(x_data)
    loss= criterion(y_pred, y_data)
    epoch_list.append(epoch)
    loss_list.append(loss.item())
    print(epoch, loss.item())
 
    optimizer.zero_grad() #梯度归零
    loss.backward()  #反向传播
    optimizer.step() #更新参数
 
print("w= ", model.linear.weight.item())
print("b= ",model.linear.bias.item())
 
x_test= torch.Tensor([[7.0]])
y_test= model(x_test)
print("y_pred= ",y_test.data)
 
plt.plot(epoch_list, loss_list)
plt.xlabel("epoch")
plt.ylabel("loss_val")
plt.show()

使用SGD优化器图像:                                                      

使用pytorch实现线性回归

使用ASGD优化器图像:

使用pytorch实现线性回归

使用Adagrad优化器图像:                                                 

使用pytorch实现线性回归

使用Adamax优化器图像:

使用pytorch实现线性回归

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 多线程实现检测服务器在线情况
Nov 25 Python
Python中的变量和作用域详解
Jul 13 Python
Golang与python线程详解及简单实例
Apr 27 Python
查看django版本的方法分享
May 14 Python
Django应用程序入口WSGIHandler源码解析
Aug 05 Python
python爬虫 爬取超清壁纸代码实例
Aug 16 Python
python3 下载网络图片代码实例
Aug 27 Python
Python实现变声器功能(萝莉音御姐音)
Dec 05 Python
使用Python求解带约束的最优化问题详解
Feb 11 Python
使用 Python 遍历目录树的方法
Feb 29 Python
如何查看Django ORM执行的SQL语句的实现
Apr 20 Python
基于Python爬取51cto博客页面信息过程解析
Aug 25 Python
pytorch实现线性回归以及多元回归
python如何获取网络数据
Apr 11 #Python
Pytorch 使用tensor特定条件判断索引
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
You might like
PHP验证码函数代码(简单实用)
2013/09/29 PHP
php使用多个进程同时控制文件读写示例
2014/02/28 PHP
让codeigniter与swfupload整合的最佳解决方案
2014/06/12 PHP
php备份数据库类分享
2015/04/14 PHP
再谈PHP中单双引号的区别详解
2016/06/12 PHP
zen cart实现订单中增加paypal中预留电话的方法
2016/07/12 PHP
PHP的CURL方法curl_setopt()函数案例介绍(抓取网页,POST数据)
2016/12/14 PHP
PHPUnit测试私有属性和方法功能示例
2018/06/12 PHP
php使用redis的几种常见操作方式和用法示例
2020/02/20 PHP
JQuery 初体验(建议学习jquery)
2009/04/25 Javascript
javascript温习的一些笔记 基础常用知识小结
2011/06/22 Javascript
Array 重排序方法和操作方法的简单实例
2014/01/24 Javascript
iframe窗口高度自适应的又一个巧妙实现思路
2014/04/04 Javascript
javascript trim函数在IE下不能用的解决方法
2014/09/12 Javascript
用move.js库实现百叶窗特效
2017/02/08 Javascript
ES6新特性之Symbol类型用法分析
2017/03/31 Javascript
史上最全JavaScript常用的简写技巧(推荐)
2017/08/17 Javascript
Vue2.0仿饿了么webapp单页面应用详细步骤
2018/07/08 Javascript
Vue核心概念Getter的使用方法
2019/01/18 Javascript
一些可能会用到的Node.js面试题
2019/06/15 Javascript
JS开发自己的类库实例分析
2019/08/28 Javascript
[05:00]第二届DOTA2亚洲邀请赛主赛事第三天比赛集锦.mp4
2017/04/04 DOTA
[02:22]《新闻直播间》2017年08月14日
2017/08/15 DOTA
Python cookbook(数据结构与算法)从字典中提取子集的方法示例
2018/03/22 Python
python调试神器PySnooper的使用
2019/07/03 Python
python实现最大子序和(分治+动态规划)
2019/07/05 Python
django 模型中的计算字段实例
2020/05/19 Python
解决Python3.8运行tornado项目报NotImplementedError错误
2020/09/02 Python
荷兰网上买鞋:MooieSchoenen.nl
2017/09/12 全球购物
会计出纳岗位职责
2013/12/25 职场文书
给物业的表扬信
2014/01/21 职场文书
办公室岗位职责
2014/02/12 职场文书
总经理任命书
2014/03/29 职场文书
2015年社区纪检工作总结
2015/04/21 职场文书
2016毕业实习单位评语大全
2015/12/01 职场文书
CSS中em的正确打开方式详解
2021/04/08 HTML / CSS