Pytorch实现LSTM和GRU示例


Posted in Python onJanuary 14, 2020

为了解决传统RNN无法长时依赖问题,RNN的两个变体LSTM和GRU被引入。

LSTM

Long Short Term Memory,称为长短期记忆网络,意思就是长的短时记忆,其解决的仍然是短时记忆问题,这种短时记忆比较长,能一定程度上解决长时依赖。

Pytorch实现LSTM和GRU示例

上图为LSTM的抽象结构,LSTM由3个门来控制,分别是输入门、遗忘门和输出门。输入门控制网络的输入,遗忘门控制着记忆单元,输出门控制着网络的输出。最为重要的就是遗忘门,可以决定哪些记忆被保留,由于遗忘门的作用,使得LSTM具有长时记忆的功能。对于给定的任务,遗忘门能够自主学习保留多少之前的记忆,网络能够自主学习。

具体看LSTM单元的内部结构:

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

在每篇文章中,作者都会使用和标准LSTM稍微不同的版本,针对特定的任务,特定的网络结构往往表现更好。

GRU

Pytorch实现LSTM和GRU示例

上述的过程的线性变换没有使用偏置。隐藏状态参数不再是标准RNN的4倍,而是3倍,也就是GRU的参数要比LSTM的参数量要少,但是性能差不多。

Pytorch

在Pytorch中使用nn.LSTM()可调用,参数和RNN的参数相同。具体介绍LSTM的输入和输出:

输入: input, (h_0, c_0)

input:输入数据with维度(seq_len,batch,input_size)

h_0:维度为(num_layers*num_directions,batch,hidden_size),在batch中的

初始的隐藏状态.

c_0:初始的单元状态,维度与h_0相同

输出:output, (h_n, c_n)

output:维度为(seq_len, batch, num_directions * hidden_size)。

h_n:最后时刻的输出隐藏状态,维度为 (num_layers * num_directions, batch, hidden_size)

c_n:最后时刻的输出单元状态,维度与h_n相同。

LSTM的变量:

Pytorch实现LSTM和GRU示例

以MNIST分类为例实现LSTM分类

MNIST图片大小为28×28,可以将每张图片看做是长为28的序列,序列中每个元素的特征维度为28。将最后输出的隐藏状态Pytorch实现LSTM和GRU示例 作为抽象的隐藏特征输入到全连接层进行分类。最后输出的

导入头文件:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
class Rnn(nn.Module):
  def __init__(self, in_dim, hidden_dim, n_layer, n_classes):
    super(Rnn, self).__init__()
    self.n_layer = n_layer
    self.hidden_dim = hidden_dim
    self.lstm = nn.LSTM(in_dim, hidden_dim, n_layer, batch_first=True)
    self.classifier = nn.Linear(hidden_dim, n_classes)

  def forward(self, x):
    out, (h_n, c_n) = self.lstm(x)
    # 此时可以从out中获得最终输出的状态h
    # x = out[:, -1, :]
    x = h_n[-1, :, :]
    x = self.classifier(x)
    return x

训练和测试代码:

transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize([0.5], [0.5]),
])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)

net = Rnn(28, 10, 2, 10)

net = net.to('cpu')
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)

# Training
def train(epoch):
  print('\nEpoch: %d' % epoch)
  net.train()
  train_loss = 0
  correct = 0
  total = 0
  for batch_idx, (inputs, targets) in enumerate(trainloader):
    inputs, targets = inputs.to('cpu'), targets.to('cpu')
    optimizer.zero_grad()
    outputs = net(torch.squeeze(inputs, 1))
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    train_loss += loss.item()
    _, predicted = outputs.max(1)
    total += targets.size(0)
    correct += predicted.eq(targets).sum().item()

    print(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
      % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

def test(epoch):
  global best_acc
  net.eval()
  test_loss = 0
  correct = 0
  total = 0
  with torch.no_grad():
    for batch_idx, (inputs, targets) in enumerate(testloader):
      inputs, targets = inputs.to('cpu'), targets.to('cpu')
      outputs = net(torch.squeeze(inputs, 1))
      loss = criterion(outputs, targets)

      test_loss += loss.item()
      _, predicted = outputs.max(1)
      total += targets.size(0)
      correct += predicted.eq(targets).sum().item()

      print(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
        % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))




for epoch in range(200):
  train(epoch)
  test(epoch)

以上这篇Pytorch实现LSTM和GRU示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python之Web框架Django项目搭建全过程
May 02 Python
Python日期的加减等操作的示例
Aug 15 Python
python生成n个元素的全组合方法
Nov 13 Python
用python标准库difflib比较两份文件的异同详解
Nov 16 Python
利用Python实现原创工具的Logo与Help
Dec 03 Python
OpenCV 边缘检测
Jul 10 Python
简单了解django文件下载方式
Feb 10 Python
django序列化时使用外键的真实值操作
Jul 15 Python
python中random.randint和random.randrange的区别详解
Sep 20 Python
python 装饰器的使用示例
Oct 10 Python
scrapy redis配置文件setting参数详解
Nov 18 Python
Python中的np.argmin()和np.argmax()函数用法
Jun 02 Python
Python生成词云的实现代码
Jan 14 #Python
pytorch-RNN进行回归曲线预测方式
Jan 14 #Python
利用pytorch实现对CIFAR-10数据集的分类
Jan 14 #Python
pytorch下使用LSTM神经网络写诗实例
Jan 14 #Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
You might like
PHP 获取远程文件内容的函数代码
2010/03/24 PHP
解析php中的escape函数
2013/06/29 PHP
用php和jQuery来实现“顶”和“踩”的投票功能
2016/10/13 PHP
Laravel框架在本地虚拟机快速安装的方法详解
2018/06/11 PHP
php反射学习之依赖注入示例
2019/06/14 PHP
php中get_object_vars()在数组的实例用法
2021/02/22 PHP
Jquery 选中表格一列并对表格排序实现原理
2012/12/15 Javascript
用显卡加速,轻松把笔记本打造成取暖器的办法!
2013/04/17 Javascript
js获取 type=radio 值的方法
2014/05/09 Javascript
Javascript闭包用法实例分析
2015/01/23 Javascript
JavaScript获取DOM元素的11种方法总结
2015/04/25 Javascript
学习JavaScript图片预加载模块
2016/11/07 Javascript
Angular实现的简单定时器功能示例
2017/12/28 Javascript
Swiper 4.x 使用方法(移动端网站的内容触摸滑动)
2018/05/17 Javascript
vue组件name的作用小结
2018/05/23 Javascript
在vue中使用echarts(折线图的demo,markline用法)
2020/07/20 Javascript
Vue 修改网站图标的方法
2020/12/31 Vue.js
[01:19]2014DOTA2国际邀请赛 采访TITAN战队ohaiyo 能赢DK很幸运
2014/07/12 DOTA
Python3搜索及替换文件中文本的方法
2015/05/22 Python
Python中绑定与未绑定的类方法用法分析
2016/04/29 Python
python模块之paramiko实例代码
2018/01/31 Python
Python向Excel中插入图片的简单实现方法
2018/04/24 Python
在Python中将函数作为另一个函数的参数传入并调用的方法
2019/01/22 Python
详解使用python绘制混淆矩阵(confusion_matrix)
2019/07/14 Python
弄清Pytorch显存的分配机制
2020/12/10 Python
HTML5超文本标记语言的实现方法
2020/09/24 HTML / CSS
夏威夷咖啡公司:Hawaii Coffee Company
2019/09/19 全球购物
就业推荐自我鉴定
2013/10/06 职场文书
银行实习自我鉴定
2013/10/12 职场文书
接受捐赠答谢词
2014/01/27 职场文书
小学国庆节活动方案
2014/02/11 职场文书
个人委托书范本
2014/09/13 职场文书
2014年学校党建工作汇报材料
2014/11/02 职场文书
2016母亲节感恩话语
2015/12/09 职场文书
导游词之西安骊山
2019/12/20 职场文书
css背景和边框标签实例详解
2021/05/21 HTML / CSS