Pytorch实现LSTM和GRU示例


Posted in Python onJanuary 14, 2020

为了解决传统RNN无法长时依赖问题,RNN的两个变体LSTM和GRU被引入。

LSTM

Long Short Term Memory,称为长短期记忆网络,意思就是长的短时记忆,其解决的仍然是短时记忆问题,这种短时记忆比较长,能一定程度上解决长时依赖。

Pytorch实现LSTM和GRU示例

上图为LSTM的抽象结构,LSTM由3个门来控制,分别是输入门、遗忘门和输出门。输入门控制网络的输入,遗忘门控制着记忆单元,输出门控制着网络的输出。最为重要的就是遗忘门,可以决定哪些记忆被保留,由于遗忘门的作用,使得LSTM具有长时记忆的功能。对于给定的任务,遗忘门能够自主学习保留多少之前的记忆,网络能够自主学习。

具体看LSTM单元的内部结构:

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

在每篇文章中,作者都会使用和标准LSTM稍微不同的版本,针对特定的任务,特定的网络结构往往表现更好。

GRU

Pytorch实现LSTM和GRU示例

上述的过程的线性变换没有使用偏置。隐藏状态参数不再是标准RNN的4倍,而是3倍,也就是GRU的参数要比LSTM的参数量要少,但是性能差不多。

Pytorch

在Pytorch中使用nn.LSTM()可调用,参数和RNN的参数相同。具体介绍LSTM的输入和输出:

输入: input, (h_0, c_0)

input:输入数据with维度(seq_len,batch,input_size)

h_0:维度为(num_layers*num_directions,batch,hidden_size),在batch中的

初始的隐藏状态.

c_0:初始的单元状态,维度与h_0相同

输出:output, (h_n, c_n)

output:维度为(seq_len, batch, num_directions * hidden_size)。

h_n:最后时刻的输出隐藏状态,维度为 (num_layers * num_directions, batch, hidden_size)

c_n:最后时刻的输出单元状态,维度与h_n相同。

LSTM的变量:

Pytorch实现LSTM和GRU示例

以MNIST分类为例实现LSTM分类

MNIST图片大小为28×28,可以将每张图片看做是长为28的序列,序列中每个元素的特征维度为28。将最后输出的隐藏状态Pytorch实现LSTM和GRU示例 作为抽象的隐藏特征输入到全连接层进行分类。最后输出的

导入头文件:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
class Rnn(nn.Module):
  def __init__(self, in_dim, hidden_dim, n_layer, n_classes):
    super(Rnn, self).__init__()
    self.n_layer = n_layer
    self.hidden_dim = hidden_dim
    self.lstm = nn.LSTM(in_dim, hidden_dim, n_layer, batch_first=True)
    self.classifier = nn.Linear(hidden_dim, n_classes)

  def forward(self, x):
    out, (h_n, c_n) = self.lstm(x)
    # 此时可以从out中获得最终输出的状态h
    # x = out[:, -1, :]
    x = h_n[-1, :, :]
    x = self.classifier(x)
    return x

训练和测试代码:

transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize([0.5], [0.5]),
])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)

net = Rnn(28, 10, 2, 10)

net = net.to('cpu')
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)

# Training
def train(epoch):
  print('\nEpoch: %d' % epoch)
  net.train()
  train_loss = 0
  correct = 0
  total = 0
  for batch_idx, (inputs, targets) in enumerate(trainloader):
    inputs, targets = inputs.to('cpu'), targets.to('cpu')
    optimizer.zero_grad()
    outputs = net(torch.squeeze(inputs, 1))
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    train_loss += loss.item()
    _, predicted = outputs.max(1)
    total += targets.size(0)
    correct += predicted.eq(targets).sum().item()

    print(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
      % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

def test(epoch):
  global best_acc
  net.eval()
  test_loss = 0
  correct = 0
  total = 0
  with torch.no_grad():
    for batch_idx, (inputs, targets) in enumerate(testloader):
      inputs, targets = inputs.to('cpu'), targets.to('cpu')
      outputs = net(torch.squeeze(inputs, 1))
      loss = criterion(outputs, targets)

      test_loss += loss.item()
      _, predicted = outputs.max(1)
      total += targets.size(0)
      correct += predicted.eq(targets).sum().item()

      print(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
        % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))




for epoch in range(200):
  train(epoch)
  test(epoch)

以上这篇Pytorch实现LSTM和GRU示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python os模块学习笔记
Jun 21 Python
最近Python有点火? 给你7个学习它的理由!
Jun 26 Python
OPENCV去除小连通区域,去除孔洞的实例讲解
Jun 21 Python
Pandas过滤dataframe中包含特定字符串的数据方法
Nov 07 Python
Python生成rsa密钥对操作示例
Apr 26 Python
在PyCharm中控制台输出日志分层级分颜色显示的方法
Jul 11 Python
对python中assert、isinstance的用法详解
Nov 27 Python
Python自动化测试笔试面试题精选
Mar 12 Python
使用Pycharm分段执行代码
Apr 15 Python
Cpython解释器中的GIL全局解释器锁
Nov 09 Python
python3定位并识别图片验证码实现自动登录功能
Jan 29 Python
详解python的异常捕获
Mar 03 Python
Python生成词云的实现代码
Jan 14 #Python
pytorch-RNN进行回归曲线预测方式
Jan 14 #Python
利用pytorch实现对CIFAR-10数据集的分类
Jan 14 #Python
pytorch下使用LSTM神经网络写诗实例
Jan 14 #Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
You might like
php垃圾代码优化操作代码
2010/08/05 PHP
PHP连接SQLSERVER 注意事项(附dll文件下载)
2012/06/28 PHP
php调用shell的方法
2014/11/05 PHP
PHP中iconv函数转码时截断字符问题的解决方法
2015/01/21 PHP
PHP学习笔记之php文件操作
2016/06/03 PHP
php+redis实现商城秒杀功能
2020/11/19 PHP
跨浏览器通用、可重用的选项卡tab切换js代码
2011/09/20 Javascript
textarea焦点的用法实现获取焦点清空失去焦点提示效果
2014/05/19 Javascript
情人节单身的我是如何在敲完代码之后收到12束玫瑰的(javascript)
2015/08/21 Javascript
javascript实现C语言经典程序题
2015/11/29 Javascript
jQuery Select下拉框操作小结(推荐)
2016/07/22 Javascript
微信小程序 tabs选项卡效果的实现
2017/01/05 Javascript
jQuery 全选 全不选 事件绑定的实现代码
2017/01/23 Javascript
Iscrool下拉刷新功能实现方法(推荐)
2017/06/26 Javascript
jQuery实现QQ空间汉字转拼音功能示例
2017/07/10 jQuery
Vue中之nextTick函数源码分析详解
2017/10/17 Javascript
使用ECharts实现状态区间图
2018/10/25 Javascript
微信小程序错误this.setData报错及解决过程
2019/09/18 Javascript
Vuex的API文档说明详解
2020/02/05 Javascript
微信小程序拖拽排序列表的示例代码
2020/07/08 Javascript
[01:01:04]2018DOTA2亚洲邀请赛 4.5 淘汰赛 OpTic vs TNC 第一场
2018/04/06 DOTA
Python如何实现文本转语音
2016/08/08 Python
Python 实现判断图片格式并转换,将转换的图像存到生成的文件夹中
2020/01/13 Python
Python 分布式缓存之Reids数据类型操作详解
2020/06/24 Python
解决Keras TensorFlow 混编中 trainable=False设置无效问题
2020/06/28 Python
最简单的matplotlib安装教程(小白)
2020/07/28 Python
中国高端家电购物商城:顺电
2018/03/04 全球购物
PHP数据运算类型都有哪些
2013/11/05 面试题
Linux开机引导的步骤是什么
2014/02/26 面试题
成人毕业生自我鉴定
2013/10/18 职场文书
营销与策划专业毕业生求职信
2013/11/01 职场文书
自动化专业个人求职信范文
2013/11/29 职场文书
《圆明园的毁灭》教学反思
2014/02/28 职场文书
小学评语大全
2014/04/22 职场文书
关于antd tree 和父子组件之间的传值问题(react 总结)
2021/06/02 Javascript
分享mysql的current_timestamp小坑及解决
2021/11/27 MySQL