使用 pytorch 创建神经网络拟合sin函数的实现


Posted in Python onFebruary 24, 2020

我们知道深度神经网络的本质是输入端数据和输出端数据的一种高维非线性拟合,如何更好的理解它,下面尝试拟合一个正弦函数,本文可以通过简单设置节点数,实现任意隐藏层数的拟合。

基于pytorch的深度神经网络实战,无论任务多么复杂,都可以将其拆分成必要的几个模块来进行理解。

1)构建数据集,包括输入,对应的标签y
2) 构建神经网络模型,一般基于nn.Module继承一个net类,必须的是__init__函数和forward函数。__init__构造函数包括创建该类是必须的参数,比如输入节点数,隐藏层节点数,输出节点数。forward函数则定义了整个网络的前向传播过程,类似于一个Sequential。
3)实例化上步创建的类。
4)定义损失函数(判别准则),比如均方误差,交叉熵等
5)定义优化器(optim:SGD,adam,adadelta等),设置学习率
6)开始训练。开始训练是一个从0到设定的epoch的循环,循环期间,根据loss,不断迭代和更新网络权重参数。

无论多么复杂的网络,基于pytorch的深度神经网络都包括6个模块,训练阶段包括5个步骤,本文只通过拟合一个正弦函数来说明加深理解。

废话少说,直接上代码:

from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import torch.nn as nn
import numpy as np
import torch

# 准备数据
x=np.linspace(-2*np.pi,2*np.pi,400)
y=np.sin(x)
# 将数据做成数据集的模样
X=np.expand_dims(x,axis=1)
Y=y.reshape(400,-1)
# 使用批训练方式
dataset=TensorDataset(torch.tensor(X,dtype=torch.float),torch.tensor(Y,dtype=torch.float))
dataloader=DataLoader(dataset,batch_size=100,shuffle=True)

# 神经网络主要结构,这里就是一个简单的线性结构

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.net=nn.Sequential(
      nn.Linear(in_features=1,out_features=10),nn.ReLU(),
      nn.Linear(10,100),nn.ReLU(),
      nn.Linear(100,10),nn.ReLU(),
      nn.Linear(10,1)
    )

  def forward(self, input:torch.FloatTensor):
    return self.net(input)

net=Net()

# 定义优化器和损失函数
optim=torch.optim.Adam(Net.parameters(net),lr=0.001)
Loss=nn.MSELoss()

# 下面开始训练:
# 一共训练 1000次
for epoch in range(1000):
  loss=None
  for batch_x,batch_y in dataloader:
    y_predict=net(batch_x)
    loss=Loss(y_predict,batch_y)
    optim.zero_grad()
    loss.backward()
    optim.step()
  # 每100次 的时候打印一次日志
  if (epoch+1)%100==0:
    print("step: {0} , loss: {1}".format(epoch+1,loss.item()))

# 使用训练好的模型进行预测
predict=net(torch.tensor(X,dtype=torch.float))

# 绘图展示预测的和真实数据之间的差异
import matplotlib.pyplot as plt
plt.plot(x,y,label="fact")
plt.plot(x,predict.detach().numpy(),label="predict")
plt.title("sin function")
plt.xlabel("x")
plt.ylabel("sin(x)")
plt.legend()
plt.savefig(fname="result.png",figsize=[10,10])
plt.show()

输出结果:

step: 100 , loss: 0.06755948066711426
step: 200 , loss: 0.003788222325965762
step: 300 , loss: 0.0004728269996121526
step: 400 , loss: 0.0001810075482353568
step: 500 , loss: 0.0001108720971387811
step: 600 , loss: 6.29749265499413e-05
step: 700 , loss: 3.707894938997924e-05
step: 800 , loss: 0.0001250380591955036
step: 900 , loss: 3.0654005968244746e-05
step: 1000 , loss: 4.349641676526517e-05

输出图像:

使用 pytorch 创建神经网络拟合sin函数的实现

到此这篇关于使用 pytorch 创建神经网络拟合sin函数的实现的文章就介绍到这了,更多相关pytorch 创建拟合sin函数内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python3读取UTF-8文件及统计文件行数的方法
May 22 Python
Windows下实现Python2和Python3两个版共存的方法
Jun 12 Python
深入讲解Java编程中类的生命周期
Feb 05 Python
Python IDLE 错误:IDLE''s subprocess didn''t make connection 的解决方案
Feb 13 Python
基于ID3决策树算法的实现(Python版)
May 31 Python
Python实现类的创建与使用方法示例
Jul 25 Python
Django 实现购物车功能的示例代码
Oct 08 Python
python输入整条数据分割存入数组的方法
Nov 13 Python
python3正则提取字符串里的中文实例
Jan 31 Python
python利用跳板机ssh远程连接redis的方法
Feb 19 Python
python FTP批量下载/删除/上传实例
Dec 22 Python
python线程池如何使用
May 28 Python
sklearn+python:线性回归案例
Feb 24 #Python
深入理解Tensorflow中的masking和padding
Feb 24 #Python
K最近邻算法(KNN)---sklearn+python实现方式
Feb 24 #Python
Python3.6 + TensorFlow 安装配置图文教程(Windows 64 bit)
Feb 24 #Python
Python enumerate内置库用法解析
Feb 24 #Python
Python模块/包/库安装的六种方法及区别
Feb 24 #Python
python之MSE、MAE、RMSE的使用
Feb 24 #Python
You might like
PHP与MYSQL中UTF8编码的中文排序实例
2014/10/21 PHP
php运行提示:Fatal error Allowed memory size内存不足的解决方法
2014/12/17 PHP
laravel5.0在linux下解决.htaccess无效和去除index.php的问题
2019/10/16 PHP
laravel 解决多库下的DB::transaction()事务失效问题
2019/10/21 PHP
HTML DOM的nodeType值介绍
2011/03/31 Javascript
javascript窗口宽高,鼠标位置,滚动高度(详细解析)
2013/11/18 Javascript
javascript二维数组转置实例
2015/01/22 Javascript
纯css实现窗户玻璃雨滴逼真效果
2015/08/23 Javascript
jquery常用函数与方法汇总
2015/09/01 Javascript
javascript实现日期时间动态显示示例代码
2015/09/08 Javascript
AngularJS入门教程之AngularJS指令
2016/04/18 Javascript
jQuery中on绑定事件后引发的事件冒泡问题如何解决
2016/05/25 Javascript
跨域请求的完美解决方法(JSONP, CORS)
2016/06/12 Javascript
解决bootstrap导航栏navbar在IE8上存在缺陷的方法
2016/07/01 Javascript
d3.js入门教程之数据绑定详解
2017/04/28 Javascript
Vue之mixin全局的用法详解
2018/08/22 Javascript
jquery判断滚动条距离顶部的距离方法
2018/09/05 jQuery
微信小程序显示倒计时功能示例【测试可用】
2018/12/03 Javascript
谈谈JavaScript中super(props)的重要性
2019/02/12 Javascript
微信小程序中使用echarts的实现方法
2019/04/24 Javascript
JavaScript实现图片放大镜效果
2019/06/27 Javascript
JS+DIV实现拖动效果
2020/02/11 Javascript
Python中常用的8种字符串操作方法
2019/05/06 Python
pytorch 模型的train模式与eval模式实例
2020/02/20 Python
python编程进阶之异常处理用法实例分析
2020/02/21 Python
python实现从尾到头打印单链表操作示例
2020/02/22 Python
解决pycharm中的run和debug失效无法点击运行
2020/06/09 Python
Python如何把字典写入到CSV文件的方法示例
2020/08/23 Python
如何用Python 加密文件
2020/09/10 Python
HTML5离线应用与客户端存储的实现
2018/05/03 HTML / CSS
广州喜创信息技术有限公司JAVA软件工程师笔试题
2012/10/17 面试题
大学校园毕业自我鉴定
2014/01/15 职场文书
梅花魂教学反思
2014/04/25 职场文书
python中24小时制转换为12小时制的方法
2021/06/18 Python
IDEA 2022 Translation 未知错误 翻译文档失败
2022/04/24 Java/Android
el-form每行显示两列底部按钮居中效果的实现
2022/08/05 HTML / CSS