基于pytorch的lstm参数使用详解


Posted in Python onJanuary 14, 2020

lstm(*input, **kwargs)

将多层长短时记忆(LSTM)神经网络应用于输入序列。

参数:

input_size:输入'x'中预期特性的数量

hidden_size:隐藏状态'h'中的特性数量

num_layers:循环层的数量。例如,设置' ' num_layers=2 ' '意味着将两个LSTM堆叠在一起,形成一个'堆叠的LSTM ',第二个LSTM接收第一个LSTM的输出并计算最终结果。默认值:1

bias:如果' False',则该层不使用偏置权重' b_ih '和' b_hh '。默认值:'True'

batch_first:如果' 'True ' ',则输入和输出张量作为(batch, seq, feature)提供。默认值: 'False'

dropout:如果非零,则在除最后一层外的每个LSTM层的输出上引入一个“dropout”层,相当于:attr:'dropout'。默认值:0

bidirectional:如果‘True',则成为双向LSTM。默认值:'False'

输入:input,(h_0, c_0)

**input**of shape (seq_len, batch, input_size):包含输入序列特征的张量。输入也可以是一个压缩的可变长度序列。

see:func:'torch.nn.utils.rnn.pack_padded_sequence' 或:func:'torch.nn.utils.rnn.pack_sequence' 的细节。

**h_0** of shape (num_layers * num_directions, batch, hidden_size):张量包含批处理中每个元素的初始隐藏状态。

如果RNN是双向的,num_directions应该是2,否则应该是1。

**c_0** of shape (num_layers * num_directions, batch, hidden_size):张量包含批处理中每个元素的初始单元格状态。

如果没有提供' (h_0, c_0) ',则**h_0**和**c_0**都默认为零。

输出:output,(h_n, c_n)

**output**of shape (seq_len, batch, num_directions * hidden_size) :包含LSTM最后一层输出特征' (h_t) '张量,

对于每个t. If a:class: 'torch.nn.utils.rnn.PackedSequence' 已经给出,输出也将是一个打包序列。

对于未打包的情况,可以使用'output.view(seq_len, batch, num_directions, hidden_size)',正向和反向分别为方向' 0 '和' 1 '。

同样,在包装的情况下,方向可以分开。

**h_n** of shape (num_layers * num_directions, batch, hidden_size):包含' t = seq_len '隐藏状态的张量。

与*output*类似, the layers可以使用以下命令分隔

h_n.view(num_layers, num_directions, batch, hidden_size) 对于'c_n'相似

**c_n** (num_layers * num_directions, batch, hidden_size):张量包含' t = seq_len '的单元状态

所有的权重和偏差都初始化自: 基于pytorch的lstm参数使用详解 where: 基于pytorch的lstm参数使用详解

include:: cudnn_persistent_rnn.rst
import torch
import torch.nn as nn
 
# 双向rnn例子
# rnn = nn.RNN(10, 20, 2)
# input = torch.randn(5, 3, 10)
# h0 = torch.randn(2, 3, 20)
# output, hn = rnn(input, h0)
# print(output.shape,hn.shape)
# torch.Size([5, 3, 20]) torch.Size([2, 3, 20])
 
# 双向lstm例子
rnn = nn.LSTM(10, 20, 2)   #(input_size,hidden_size,num_layers)
input = torch.randn(5, 3, 10)  #(seq_len, batch, input_size)
h0 = torch.randn(2, 3, 20)    #(num_layers * num_directions, batch, hidden_size)
c0 = torch.randn(2, 3, 20)    #(num_layers * num_directions, batch, hidden_size)
# output:(seq_len, batch, num_directions * hidden_size)
# hn,cn(num_layers * num_directions, batch, hidden_size)
output, (hn, cn) = rnn(input, (h0, c0)) 
 
print(output.shape,hn.shape,cn.shape)
>>>torch.Size([5, 3, 20]) torch.Size([2, 3, 20]) torch.Size([2, 3, 20])

以上这篇基于pytorch的lstm参数使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python开发windows GUI程序入门实例
Oct 23 Python
Python实现监控程序执行时间并将其写入日志的方法
Jun 30 Python
Python判断某个用户对某个文件的权限
Oct 13 Python
Python tkinter事件高级用法实例
Jan 31 Python
python清除函数占用的内存方法
Jun 25 Python
Python定时发送消息的脚本:每天跟你女朋友说晚安
Oct 21 Python
3分钟学会一个Python小技巧
Nov 23 Python
python循环定时中断执行某一段程序的实例
Jun 29 Python
django将网络中的图片,保存成model中的ImageField的实例
Aug 07 Python
Python搭建代理IP池实现接口设置与整体调度
Oct 27 Python
Python TKinter如何自动关闭主窗口
Feb 26 Python
Keras中的两种模型:Sequential和Model用法
Jun 27 Python
Python利用逻辑回归模型解决MNIST手写数字识别问题详解
Jan 14 #Python
np.random.seed() 的使用详解
Jan 14 #Python
下载与当前Chrome对应的chromedriver.exe(用于python+selenium)
Jan 14 #Python
Python selenium 自动化脚本打包成一个exe文件(推荐)
Jan 14 #Python
pytorch+lstm实现的pos示例
Jan 14 #Python
Python中sorted()排序与字母大小写的问题
Jan 14 #Python
Pytorch实现LSTM和GRU示例
Jan 14 #Python
You might like
PHP配置心得包含MYSQL5乱码解决
2006/11/20 PHP
PHP session文件独占锁引起阻塞问题解决方法
2015/05/12 PHP
php获得文件夹下所有文件的递归算法的简单实例
2016/11/01 PHP
PHP面向对象程序设计模拟一般面向对象语言中的方法重载(overload)示例
2019/06/13 PHP
Laravel实现通过blade模板引擎渲染视图
2019/10/25 PHP
jquery ajax jsonp跨域调用实例代码
2013/12/11 Javascript
JavaScript数组常用操作技巧汇总
2014/11/17 Javascript
JavaScript实现大数的运算
2014/11/24 Javascript
JS HTML5拖拽上传图片预览
2016/07/18 Javascript
微信小程序实现页面跳转传值的方法
2017/10/12 Javascript
Vue-router路由判断页面未登录跳转到登录页面的实例
2017/10/26 Javascript
Vue组件的使用及个人理解与介绍
2019/02/09 Javascript
详解ES6中class的实现原理
2020/10/03 Javascript
移动端JS实现拖拽两种方法解析
2020/10/12 Javascript
Vue使用CDN引用项目组件,减少项目体积的步骤
2020/10/30 Javascript
[01:08:24]DOTA2-DPC中国联赛 正赛 RNG vs Phoenix BO3 第一场 2月5日
2021/03/11 DOTA
python解析xml文件实例分享
2013/12/04 Python
python获取beautifulphoto随机某图片代码实例
2013/12/18 Python
python删除文本中行数标签的方法
2018/05/31 Python
Python的iOS自动化打包实例代码
2018/11/22 Python
使用Python3+PyQT5+Pyserial 实现简单的串口工具方法
2019/02/13 Python
Python实现桌面翻译工具【新手必学】
2020/02/12 Python
python 中的paramiko模块简介及安装过程
2020/02/29 Python
python入门之井字棋小游戏
2020/03/05 Python
python中count函数知识点浅析
2020/12/17 Python
关联、聚合(Aggregation)以及组合(Composition)的区别
2012/02/29 面试题
运动会广播稿400字
2014/01/25 职场文书
工作年限证明模板
2014/11/01 职场文书
校园运动会广播稿
2015/08/19 职场文书
2015年度个人工作总结报告
2015/10/24 职场文书
少儿励志名言(80句)
2019/08/14 职场文书
怎样写好演讲稿题目?
2019/08/21 职场文书
golang DNS服务器的简单实现操作
2021/04/30 Golang
MySQL如何构建数据表索引
2021/05/13 MySQL
react antd实现动态增减表单
2021/06/03 Javascript
Nginx 502 bad gateway错误解决的九种方案及原因
2022/08/14 Servers