使用 pytorch 创建神经网络拟合sin函数的实现


Posted in Python onFebruary 24, 2020

我们知道深度神经网络的本质是输入端数据和输出端数据的一种高维非线性拟合,如何更好的理解它,下面尝试拟合一个正弦函数,本文可以通过简单设置节点数,实现任意隐藏层数的拟合。

基于pytorch的深度神经网络实战,无论任务多么复杂,都可以将其拆分成必要的几个模块来进行理解。

1)构建数据集,包括输入,对应的标签y
2) 构建神经网络模型,一般基于nn.Module继承一个net类,必须的是__init__函数和forward函数。__init__构造函数包括创建该类是必须的参数,比如输入节点数,隐藏层节点数,输出节点数。forward函数则定义了整个网络的前向传播过程,类似于一个Sequential。
3)实例化上步创建的类。
4)定义损失函数(判别准则),比如均方误差,交叉熵等
5)定义优化器(optim:SGD,adam,adadelta等),设置学习率
6)开始训练。开始训练是一个从0到设定的epoch的循环,循环期间,根据loss,不断迭代和更新网络权重参数。

无论多么复杂的网络,基于pytorch的深度神经网络都包括6个模块,训练阶段包括5个步骤,本文只通过拟合一个正弦函数来说明加深理解。

废话少说,直接上代码:

from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import torch.nn as nn
import numpy as np
import torch

# 准备数据
x=np.linspace(-2*np.pi,2*np.pi,400)
y=np.sin(x)
# 将数据做成数据集的模样
X=np.expand_dims(x,axis=1)
Y=y.reshape(400,-1)
# 使用批训练方式
dataset=TensorDataset(torch.tensor(X,dtype=torch.float),torch.tensor(Y,dtype=torch.float))
dataloader=DataLoader(dataset,batch_size=100,shuffle=True)

# 神经网络主要结构,这里就是一个简单的线性结构

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.net=nn.Sequential(
      nn.Linear(in_features=1,out_features=10),nn.ReLU(),
      nn.Linear(10,100),nn.ReLU(),
      nn.Linear(100,10),nn.ReLU(),
      nn.Linear(10,1)
    )

  def forward(self, input:torch.FloatTensor):
    return self.net(input)

net=Net()

# 定义优化器和损失函数
optim=torch.optim.Adam(Net.parameters(net),lr=0.001)
Loss=nn.MSELoss()

# 下面开始训练:
# 一共训练 1000次
for epoch in range(1000):
  loss=None
  for batch_x,batch_y in dataloader:
    y_predict=net(batch_x)
    loss=Loss(y_predict,batch_y)
    optim.zero_grad()
    loss.backward()
    optim.step()
  # 每100次 的时候打印一次日志
  if (epoch+1)%100==0:
    print("step: {0} , loss: {1}".format(epoch+1,loss.item()))

# 使用训练好的模型进行预测
predict=net(torch.tensor(X,dtype=torch.float))

# 绘图展示预测的和真实数据之间的差异
import matplotlib.pyplot as plt
plt.plot(x,y,label="fact")
plt.plot(x,predict.detach().numpy(),label="predict")
plt.title("sin function")
plt.xlabel("x")
plt.ylabel("sin(x)")
plt.legend()
plt.savefig(fname="result.png",figsize=[10,10])
plt.show()

输出结果:

step: 100 , loss: 0.06755948066711426
step: 200 , loss: 0.003788222325965762
step: 300 , loss: 0.0004728269996121526
step: 400 , loss: 0.0001810075482353568
step: 500 , loss: 0.0001108720971387811
step: 600 , loss: 6.29749265499413e-05
step: 700 , loss: 3.707894938997924e-05
step: 800 , loss: 0.0001250380591955036
step: 900 , loss: 3.0654005968244746e-05
step: 1000 , loss: 4.349641676526517e-05

输出图像:

使用 pytorch 创建神经网络拟合sin函数的实现

到此这篇关于使用 pytorch 创建神经网络拟合sin函数的实现的文章就介绍到这了,更多相关pytorch 创建拟合sin函数内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python实现问号表达式(?)的方法
Nov 27 Python
python实现批量监控网站
Sep 09 Python
pycharm+django创建一个搜索网页实例代码
Jan 24 Python
python 循环读取txt文档 并转换成csv的方法
Oct 26 Python
Python3 Post登录并且保存cookie登录其他页面的方法
Dec 28 Python
Python SSL证书验证问题解决方案
Jan 13 Python
TensorFlow实现指数衰减学习率的方法
Feb 05 Python
Python3基本输入与输出操作实例分析
Feb 14 Python
python发qq消息轰炸虐狗好友思路详解(完整代码)
Feb 15 Python
Python实现在Windows平台修改文件属性
Mar 05 Python
python识别验证码的思路及解决方案
Sep 13 Python
python实现快速文件格式批量转换的方法
Oct 16 Python
sklearn+python:线性回归案例
Feb 24 #Python
深入理解Tensorflow中的masking和padding
Feb 24 #Python
K最近邻算法(KNN)---sklearn+python实现方式
Feb 24 #Python
Python3.6 + TensorFlow 安装配置图文教程(Windows 64 bit)
Feb 24 #Python
Python enumerate内置库用法解析
Feb 24 #Python
Python模块/包/库安装的六种方法及区别
Feb 24 #Python
python之MSE、MAE、RMSE的使用
Feb 24 #Python
You might like
无数据库的详细域名查询程序PHP版(5)
2006/10/09 PHP
PHP全功能无变形图片裁剪操作类与用法示例
2017/01/10 PHP
PHP中通过getopt解析GNU C风格命令行选项
2019/11/18 PHP
java script编程起步(第三课)
2007/01/10 Javascript
JS array 数组详解
2009/03/22 Javascript
一些主流JS框架中DOMReady事件的实现小结
2011/02/12 Javascript
jquery实现多级下拉菜单的实例代码
2013/10/02 Javascript
防止jQuery ajax Load使用缓存的方法小结
2014/02/22 Javascript
使用JS取得焦点(focus)元素代码
2014/03/22 Javascript
JavaScript onkeydown事件入门实例(键盘某个按键被按下)
2014/10/17 Javascript
js中键盘事件实例简析
2015/01/10 Javascript
浅析jquery与checkbox的checked属性的问题
2016/04/27 Javascript
微信小程序实现移动端滑动分页效果(ajax)
2017/06/13 Javascript
微信小程序模拟cookie的实现
2018/06/20 Javascript
实例分析javascript中的异步
2020/06/02 Javascript
你不知道的 TypeScript 高级类型(小结)
2020/08/28 Javascript
pycharm 使用心得(三)Hello world!
2014/06/05 Python
rabbitmq(中间消息代理)在python中的使用详解
2017/12/14 Python
django初始化数据库的实例
2018/05/27 Python
python实现requests发送/上传多个文件的示例
2018/06/04 Python
Numpy截取指定范围内的数据方法
2018/11/14 Python
Python将string转换到float的实例方法
2019/07/29 Python
numpy中的meshgrid函数的使用
2019/07/31 Python
浅谈python已知元素,获取元素索引(numpy,pandas)
2019/11/26 Python
Python任务调度模块APScheduler使用
2020/04/15 Python
Python Django中间件使用原理及流程分析
2020/06/13 Python
HTML5本地存储localStorage、sessionStorage基本用法、遍历操作、异常处理等
2014/05/08 HTML / CSS
美国知名日用品连锁超市:Dollar General(多来店)
2017/01/14 全球购物
网上常见的一份Linux面试题(多项选择部分)
2015/02/07 面试题
物业管理计划书
2014/01/10 职场文书
读书小明星事迹材料
2014/05/03 职场文书
做一个有道德的人活动实施方案
2014/08/23 职场文书
先进集体事迹材料范文
2014/12/25 职场文书
2015年上半年信访工作总结
2015/03/30 职场文书
歌咏比赛口号大全
2015/12/25 职场文书
导游词之南迦巴瓦峰
2019/11/19 职场文书