使用 pytorch 创建神经网络拟合sin函数的实现


Posted in Python onFebruary 24, 2020

我们知道深度神经网络的本质是输入端数据和输出端数据的一种高维非线性拟合,如何更好的理解它,下面尝试拟合一个正弦函数,本文可以通过简单设置节点数,实现任意隐藏层数的拟合。

基于pytorch的深度神经网络实战,无论任务多么复杂,都可以将其拆分成必要的几个模块来进行理解。

1)构建数据集,包括输入,对应的标签y
2) 构建神经网络模型,一般基于nn.Module继承一个net类,必须的是__init__函数和forward函数。__init__构造函数包括创建该类是必须的参数,比如输入节点数,隐藏层节点数,输出节点数。forward函数则定义了整个网络的前向传播过程,类似于一个Sequential。
3)实例化上步创建的类。
4)定义损失函数(判别准则),比如均方误差,交叉熵等
5)定义优化器(optim:SGD,adam,adadelta等),设置学习率
6)开始训练。开始训练是一个从0到设定的epoch的循环,循环期间,根据loss,不断迭代和更新网络权重参数。

无论多么复杂的网络,基于pytorch的深度神经网络都包括6个模块,训练阶段包括5个步骤,本文只通过拟合一个正弦函数来说明加深理解。

废话少说,直接上代码:

from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import torch.nn as nn
import numpy as np
import torch

# 准备数据
x=np.linspace(-2*np.pi,2*np.pi,400)
y=np.sin(x)
# 将数据做成数据集的模样
X=np.expand_dims(x,axis=1)
Y=y.reshape(400,-1)
# 使用批训练方式
dataset=TensorDataset(torch.tensor(X,dtype=torch.float),torch.tensor(Y,dtype=torch.float))
dataloader=DataLoader(dataset,batch_size=100,shuffle=True)

# 神经网络主要结构,这里就是一个简单的线性结构

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.net=nn.Sequential(
      nn.Linear(in_features=1,out_features=10),nn.ReLU(),
      nn.Linear(10,100),nn.ReLU(),
      nn.Linear(100,10),nn.ReLU(),
      nn.Linear(10,1)
    )

  def forward(self, input:torch.FloatTensor):
    return self.net(input)

net=Net()

# 定义优化器和损失函数
optim=torch.optim.Adam(Net.parameters(net),lr=0.001)
Loss=nn.MSELoss()

# 下面开始训练:
# 一共训练 1000次
for epoch in range(1000):
  loss=None
  for batch_x,batch_y in dataloader:
    y_predict=net(batch_x)
    loss=Loss(y_predict,batch_y)
    optim.zero_grad()
    loss.backward()
    optim.step()
  # 每100次 的时候打印一次日志
  if (epoch+1)%100==0:
    print("step: {0} , loss: {1}".format(epoch+1,loss.item()))

# 使用训练好的模型进行预测
predict=net(torch.tensor(X,dtype=torch.float))

# 绘图展示预测的和真实数据之间的差异
import matplotlib.pyplot as plt
plt.plot(x,y,label="fact")
plt.plot(x,predict.detach().numpy(),label="predict")
plt.title("sin function")
plt.xlabel("x")
plt.ylabel("sin(x)")
plt.legend()
plt.savefig(fname="result.png",figsize=[10,10])
plt.show()

输出结果:

step: 100 , loss: 0.06755948066711426
step: 200 , loss: 0.003788222325965762
step: 300 , loss: 0.0004728269996121526
step: 400 , loss: 0.0001810075482353568
step: 500 , loss: 0.0001108720971387811
step: 600 , loss: 6.29749265499413e-05
step: 700 , loss: 3.707894938997924e-05
step: 800 , loss: 0.0001250380591955036
step: 900 , loss: 3.0654005968244746e-05
step: 1000 , loss: 4.349641676526517e-05

输出图像:

使用 pytorch 创建神经网络拟合sin函数的实现

到此这篇关于使用 pytorch 创建神经网络拟合sin函数的实现的文章就介绍到这了,更多相关pytorch 创建拟合sin函数内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python2.x利用commands模块执行Linux shell命令
Mar 11 Python
python基于phantomjs实现导入图片
May 13 Python
Python 基础知识之字符串处理
Jan 06 Python
利用python批量给云主机配置安全组的方法教程
Jun 21 Python
python 实现A*算法的示例代码
Aug 13 Python
关于python之字典的嵌套,递归调用方法
Jan 21 Python
Python正则表达式实现简易计算器功能示例
May 07 Python
Python 一键制作微信好友图片墙的方法
May 16 Python
mac系统下Redis安装和使用步骤详解
Jul 09 Python
使用Python将字符串转换为格式化的日期时间字符串
Sep 01 Python
python 多维高斯分布数据生成方式
Dec 09 Python
Python3实现监控新型冠状病毒肺炎疫情的示例代码
Feb 13 Python
sklearn+python:线性回归案例
Feb 24 #Python
深入理解Tensorflow中的masking和padding
Feb 24 #Python
K最近邻算法(KNN)---sklearn+python实现方式
Feb 24 #Python
Python3.6 + TensorFlow 安装配置图文教程(Windows 64 bit)
Feb 24 #Python
Python enumerate内置库用法解析
Feb 24 #Python
Python模块/包/库安装的六种方法及区别
Feb 24 #Python
python之MSE、MAE、RMSE的使用
Feb 24 #Python
You might like
PHP静态新闻列表自动生成代码
2007/06/14 PHP
php中一个完整表单处理实现代码
2011/11/10 PHP
php+highchats生成动态统计图
2014/05/21 PHP
人脸识别测颜值、测脸龄、测相似度微信接口
2016/04/07 PHP
JavaScript 学习技巧
2010/02/17 Javascript
JavaScript高级程序设计 错误处理与调试学习笔记
2011/09/10 Javascript
返回上一页并自动刷新的JavaScript代码
2014/02/19 Javascript
如何防止回车(enter)键提交表单
2014/05/11 Javascript
JQuery对表单元素的基本操作使用总结
2014/07/18 Javascript
javascript面向对象之this关键词用法分析
2015/01/13 Javascript
js中for in语句的用法讲解
2015/04/24 Javascript
javascript获取当前的时间戳的方法汇总
2015/07/26 Javascript
JS实现漂亮的淡蓝色滑动门效果代码
2015/09/23 Javascript
AngularJs实现分页功能不带省略号的代码
2016/05/30 Javascript
js将table的每个td的内容自动赋值给其title属性的方法
2016/10/13 Javascript
jQuery实现select模糊查询(反射机制)
2017/01/14 Javascript
HTML的select控件美化
2017/03/27 Javascript
浅谈Vue.js中的v-on(事件处理)
2017/09/05 Javascript
使用webpack/gulp构建TypeScript项目的方法示例
2019/12/18 Javascript
vue等两个接口都返回结果再执行下一步的实例
2020/09/08 Javascript
浅谈nuxtjs校验登录中间件和混入(mixin)
2020/11/06 Javascript
[09:22]2014DOTA2西雅图国际邀请赛 主赛事第二日TOPPLAY
2014/07/21 DOTA
[02:40]2018年度DOTA2最佳新人-完美盛典
2018/12/16 DOTA
PyQt5主窗口动态加载Widget实例代码
2018/02/07 Python
Python3+Pycharm+PyQt5环境搭建步骤图文详解
2019/05/29 Python
python groupby 函数 as_index详解
2019/12/16 Python
Python hmac模块使用实例解析
2019/12/24 Python
Python生成六万个随机,唯一的8位数字和数字组成的随机字符串实例
2020/03/03 Python
解决django migrate报错ORA-02000: missing ALWAYS keyword
2020/07/02 Python
matplotlib相关系统目录获取方式小结
2021/02/03 Python
HTML5 textarea高度自适应的两种方案
2020/04/08 HTML / CSS
Html5移动端div固定到底部实现底部导航条的几种方式
2021/03/09 HTML / CSS
Invicta手表官方商店:百年制表历史的瑞士腕表品牌
2019/09/26 全球购物
自我评价如何写好?
2014/01/05 职场文书
好员工观后感
2015/06/17 职场文书
MySQL 如何分析查询性能
2021/05/12 MySQL