Pytorch实现LSTM和GRU示例


Posted in Python onJanuary 14, 2020

为了解决传统RNN无法长时依赖问题,RNN的两个变体LSTM和GRU被引入。

LSTM

Long Short Term Memory,称为长短期记忆网络,意思就是长的短时记忆,其解决的仍然是短时记忆问题,这种短时记忆比较长,能一定程度上解决长时依赖。

Pytorch实现LSTM和GRU示例

上图为LSTM的抽象结构,LSTM由3个门来控制,分别是输入门、遗忘门和输出门。输入门控制网络的输入,遗忘门控制着记忆单元,输出门控制着网络的输出。最为重要的就是遗忘门,可以决定哪些记忆被保留,由于遗忘门的作用,使得LSTM具有长时记忆的功能。对于给定的任务,遗忘门能够自主学习保留多少之前的记忆,网络能够自主学习。

具体看LSTM单元的内部结构:

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

在每篇文章中,作者都会使用和标准LSTM稍微不同的版本,针对特定的任务,特定的网络结构往往表现更好。

GRU

Pytorch实现LSTM和GRU示例

上述的过程的线性变换没有使用偏置。隐藏状态参数不再是标准RNN的4倍,而是3倍,也就是GRU的参数要比LSTM的参数量要少,但是性能差不多。

Pytorch

在Pytorch中使用nn.LSTM()可调用,参数和RNN的参数相同。具体介绍LSTM的输入和输出:

输入: input, (h_0, c_0)

input:输入数据with维度(seq_len,batch,input_size)

h_0:维度为(num_layers*num_directions,batch,hidden_size),在batch中的

初始的隐藏状态.

c_0:初始的单元状态,维度与h_0相同

输出:output, (h_n, c_n)

output:维度为(seq_len, batch, num_directions * hidden_size)。

h_n:最后时刻的输出隐藏状态,维度为 (num_layers * num_directions, batch, hidden_size)

c_n:最后时刻的输出单元状态,维度与h_n相同。

LSTM的变量:

Pytorch实现LSTM和GRU示例

以MNIST分类为例实现LSTM分类

MNIST图片大小为28×28,可以将每张图片看做是长为28的序列,序列中每个元素的特征维度为28。将最后输出的隐藏状态Pytorch实现LSTM和GRU示例 作为抽象的隐藏特征输入到全连接层进行分类。最后输出的

导入头文件:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
class Rnn(nn.Module):
  def __init__(self, in_dim, hidden_dim, n_layer, n_classes):
    super(Rnn, self).__init__()
    self.n_layer = n_layer
    self.hidden_dim = hidden_dim
    self.lstm = nn.LSTM(in_dim, hidden_dim, n_layer, batch_first=True)
    self.classifier = nn.Linear(hidden_dim, n_classes)

  def forward(self, x):
    out, (h_n, c_n) = self.lstm(x)
    # 此时可以从out中获得最终输出的状态h
    # x = out[:, -1, :]
    x = h_n[-1, :, :]
    x = self.classifier(x)
    return x

训练和测试代码:

transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize([0.5], [0.5]),
])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)

net = Rnn(28, 10, 2, 10)

net = net.to('cpu')
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)

# Training
def train(epoch):
  print('\nEpoch: %d' % epoch)
  net.train()
  train_loss = 0
  correct = 0
  total = 0
  for batch_idx, (inputs, targets) in enumerate(trainloader):
    inputs, targets = inputs.to('cpu'), targets.to('cpu')
    optimizer.zero_grad()
    outputs = net(torch.squeeze(inputs, 1))
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    train_loss += loss.item()
    _, predicted = outputs.max(1)
    total += targets.size(0)
    correct += predicted.eq(targets).sum().item()

    print(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
      % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

def test(epoch):
  global best_acc
  net.eval()
  test_loss = 0
  correct = 0
  total = 0
  with torch.no_grad():
    for batch_idx, (inputs, targets) in enumerate(testloader):
      inputs, targets = inputs.to('cpu'), targets.to('cpu')
      outputs = net(torch.squeeze(inputs, 1))
      loss = criterion(outputs, targets)

      test_loss += loss.item()
      _, predicted = outputs.max(1)
      total += targets.size(0)
      correct += predicted.eq(targets).sum().item()

      print(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
        % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))




for epoch in range(200):
  train(epoch)
  test(epoch)

以上这篇Pytorch实现LSTM和GRU示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python二叉树遍历的实现方法
Nov 21 Python
python实现批量获取指定文件夹下的所有文件的厂商信息
Sep 28 Python
python实现用户登陆邮件通知的方法
Jul 09 Python
Python 实现简单的电话本功能
Aug 09 Python
使用Python脚本和ADB命令实现卸载App
Feb 10 Python
Python 实现购物商城,含有用户入口和商家入口的示例
Sep 15 Python
Python基于jieba库进行简单分词及词云功能实现方法
Jun 16 Python
Python button选取本地图片并显示的实例
Jun 13 Python
python之列表推导式的用法
Nov 29 Python
Python利用Scrapy框架爬取豆瓣电影示例
Jan 17 Python
Python如何给函数库增加日志功能
Aug 04 Python
Pytorch 如何实现常用正则化
May 27 Python
Python生成词云的实现代码
Jan 14 #Python
pytorch-RNN进行回归曲线预测方式
Jan 14 #Python
利用pytorch实现对CIFAR-10数据集的分类
Jan 14 #Python
pytorch下使用LSTM神经网络写诗实例
Jan 14 #Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
You might like
PHP实现发送邮件的方法(基于简单邮件发送类)
2015/12/17 PHP
Zend Framework分发器用法示例
2016/12/11 PHP
php回调函数处理数组操作示例
2020/04/13 PHP
浏览器无法运行JAVA脚本的解决方法
2008/01/09 Javascript
各种效果的jquery ui(接口)介绍
2008/09/17 Javascript
支持ie与FireFox的剪切板操作代码
2009/09/28 Javascript
自写的一个jQuery圆角插件
2010/10/26 Javascript
Javascript实现DIV滚动自动滚动到底部的代码
2012/03/01 Javascript
js控制web打印(局部打印)方法整理
2013/05/29 Javascript
判断一个变量是数组Array类型的方法
2013/09/16 Javascript
js与jquery获取父级元素,子级元素,兄弟元素的实现方法
2014/01/09 Javascript
使用jQuery将多条数据插入模态框的实现代码
2014/10/08 Javascript
angularjs指令中的compile与link函数详解
2014/12/06 Javascript
JQuery实现级联下拉框效果实例讲解
2015/09/17 Javascript
图解js图片轮播效果
2015/12/20 Javascript
基于javascript实现窗口抖动效果
2016/01/03 Javascript
实例详解jQuery结合GridView控件的使用方法
2016/01/04 Javascript
JavaScript的React框架中的JSX语法学习入门教程
2016/03/05 Javascript
JavaScript使用FileReader实现图片上传预览效果
2020/03/27 Javascript
实例讲解Vue.js中router传参
2018/04/22 Javascript
使用layui 渲染table数据表格的实例代码
2018/08/19 Javascript
微信小程序带动画弹窗组件使用方法详解
2018/11/27 Javascript
layer弹窗在键盘按回车将反复刷新的实现方法
2019/09/25 Javascript
python 队列详解及实例代码
2016/10/18 Python
基于python 二维数组及画图的实例详解
2018/04/03 Python
pyinstaller参数介绍以及总结详解
2019/07/12 Python
python数据预处理之数据标准化的几种处理方式
2019/07/17 Python
python使用beautifulsoup4爬取酷狗音乐代码实例
2019/12/04 Python
Win10里python3创建虚拟环境的步骤
2020/01/31 Python
Python txt文件如何转换成字典
2020/11/03 Python
python爬虫搭配起Bilibili唧唧的流程分析
2020/12/01 Python
Lookfantastic希腊官网:英国知名美妆购物网站
2018/09/15 全球购物
King Apparel官网:英国街头服饰品牌
2019/09/05 全球购物
数控专业应届生求职信
2013/11/27 职场文书
2015年高校就业工作总结
2015/05/04 职场文书
2015年中秋晚会主持稿
2015/07/30 职场文书