浅谈Tensorflow 动态双向RNN的输出问题


Posted in Python onJanuary 20, 2020

tf.nn.bidirectional_dynamic_rnn()

函数:

def bidirectional_dynamic_rnn(
  cell_fw, # 前向RNN
  cell_bw, # 后向RNN
  inputs, # 输入
  sequence_length=None,# 输入序列的实际长度(可选,默认为输入序列的最大长度)
  initial_state_fw=None, # 前向的初始化状态(可选)
  initial_state_bw=None, # 后向的初始化状态(可选)
  dtype=None, # 初始化和输出的数据类型(可选)
  parallel_iterations=None,
  swap_memory=False,
  time_major=False,
  # 决定了输入输出tensor的格式:如果为true, 向量的形状必须为 `[max_time, batch_size, depth]`.
  # 如果为false, tensor的形状必须为`[batch_size, max_time, depth]`.
  scope=None
)

其中,

outputs为(output_fw, output_bw),是一个包含前向cell输出tensor和后向cell输出tensor组成的元组。假设

time_major=false,tensor的shape为[batch_size, max_time, depth]。实验中使用tf.concat(outputs, 2)将其拼接。

output_states为(output_state_fw, output_state_bw),包含了前向和后向最后的隐藏状态的组成的元组。

output_state_fw和output_state_bw的类型为LSTMStateTuple。

LSTMStateTuple由(c,h)组成,分别代表memory cell和hidden state。

返回值:

元组:(outputs, output_states)

这里还有最后的一个小问题,output_states是一个元组的元组,处理方法是用c_fw,h_fw = output_state_fw和c_bw,h_bw = output_state_bw,最后再分别将c和h状态concat起来,用tf.contrib.rnn.LSTMStateTuple()函数生成decoder端的初始状态

def encoding_layer(rnn_size,sequence_length,num_layers,rnn_inputs,keep_prob):
  # rnn_size: rnn隐层节点数量
  # sequence_length: 数据的序列长度
  # num_layers:堆叠的rnn cell数量
  # rnn_inputs: 输入tensor
  # keep_prob:
  '''Create the encoding layer'''
  for layer in range(num_layers):
    with tf.variable_scope('encode_{}'.format(layer)):
      cell_fw = tf.contrib.rnn.LSTMCell(rnn_size,initializer=tf.random_uniform_initializer(-0.1,0.1,seed=2))
      cell_fw = tf.contrib.rnn.DropoutWrapper(cell_fw,input_keep_prob=keep_prob)
 
      cell_bw = tf.contrib.rnn.LSTMCell(rnn_size,initializer=tf.random_uniform_initializer(-0.1,0.1,seed=2))
      cell_bw = tf.contrib.rnn.DropoutWrapper(cell_bw,input_keep_prob = keep_prob)
 
      enc_output,enc_state = tf.nn.bidirectional_dynamic_rnn(cell_fw,cell_bw,
                                  rnn_inputs,sequence_length,dtype=tf.float32)
 
  # join outputs since we are using a bidirectional RNN
  enc_output = tf.concat(enc_output,2) 
  return enc_output,enc_state

tf.nn.dynamic_rnn()

tf.nn.dynamic_rnn的返回值有两个:outputs和state

为了描述输出的形状,先介绍几个变量,batch_size是输入的这批数据的数量,max_time就是这批数据中序列的最长长度,如果输入的三个句子,那max_time对应的就是最长句子的单词数量,cell.output_size其实就是rnn cell中神经元的个数。

例子来说明其用法,假设你的RNN的输入input是[2,20,128],其中2是batch_size,20是文本最大长度,128是embedding_size,可以看出,有两个example,我们假设第二个文本长度只有13,剩下的7个是使用0-padding方法填充的。dynamic返回的是两个参数:outputs,state,其中outputs是[2,20,128],也就是每一个迭代隐状态的输出,state是由(c,h)组成的tuple,均为[batch,128]。

outputs. outputs是一个tensor

如果time_major==True,outputs形状为 [max_time, batch_size, cell.output_size ](要求rnn输入与rnn输出形状保持一致)

如果time_major==False(默认),outputs形状为 [ batch_size, max_time, cell.output_size ]

state. state是一个tensor。state是最终的状态,也就是序列中最后一个cell输出的状态。一般情况下state的形状为 [batch_size, cell.output_size ],但当输入的cell为BasicLSTMCell时,state的形状为[2,batch_size, cell.output_size ],其中2也对应着LSTM中的cell state和hidden state。

这里有关于LSTM的结构问题:

浅谈Tensorflow 动态双向RNN的输出问题

以上这篇浅谈Tensorflow 动态双向RNN的输出问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python两种遍历字典(dict)的方法比较
May 29 Python
Numpy数组的保存与读取方法
Apr 04 Python
浅谈Python的list中的选取范围
Nov 12 Python
由Python编写的MySQL管理工具代码实例
Apr 09 Python
python 实现将文件或文件夹用相对路径打包为 tar.gz 文件的方法
Jun 10 Python
Django框架静态文件使用/中间件/禁用ip功能实例详解
Jul 22 Python
Python实现Singleton模式的方式详解
Aug 08 Python
python matplotlib拟合直线的实现
Nov 19 Python
Pyqt5自适应布局实例
Dec 13 Python
Python3 元组tuple入门基础
Feb 09 Python
jupyter notebook tensorflow打印device信息实例
Apr 20 Python
Python类绑定方法及非绑定方法实例解析
Oct 09 Python
关于tf.nn.dynamic_rnn返回值详解
Jan 20 #Python
双向RNN:bidirectional_dynamic_rnn()函数的使用详解
Jan 20 #Python
关于tf.reverse_sequence()简述
Jan 20 #Python
tensorflow使用range_input_producer多线程读取数据实例
Jan 20 #Python
浅谈tensorflow中Dataset图片的批量读取及维度的操作详解
Jan 20 #Python
使用tensorflow DataSet实现高效加载变长文本输入
Jan 20 #Python
python机器学习库xgboost的使用
Jan 20 #Python
You might like
PHP MYSQL乱码问题,使用SET NAMES utf8校正
2009/11/30 PHP
php中url传递中文字符,特殊危险字符的解决方法
2013/08/17 PHP
隐性调用php程序的方法
2015/06/13 PHP
PHP的mysqli_stmt_init()函数讲解
2019/01/24 PHP
node.js中的fs.readSync方法使用说明
2014/12/17 Javascript
jQuery实现冻结表头的方法
2015/03/09 Javascript
JS使用ajax从xml文件动态获取数据显示的方法
2015/03/24 Javascript
jQuery实现的动态伸缩导航菜单实例
2015/05/07 Javascript
jquery事件的ready()方法使用详解
2015/11/11 Javascript
js+html5操作sqlite数据库的方法
2016/02/02 Javascript
jQuery on()方法绑定动态元素的点击事件实例代码浅析
2016/06/16 Javascript
浅谈jquery页面初始化的4种方式
2016/11/27 Javascript
利用JavaScript在网页实现八数码启发式A*算法动画效果
2017/04/16 Javascript
react系列从零开始_简单谈谈react
2017/07/06 Javascript
详解Vue一个案例引发「内容分发slot」的最全总结
2018/12/02 Javascript
[51:29]完美世界DOTA2联赛循环赛 Matador vs Forest BO2第一场 11.05
2020/11/05 DOTA
跨平台python异步回调机制实现和使用方法
2013/11/26 Python
python网络编程学习笔记(三):socket网络服务器
2014/06/09 Python
Python中的迭代器与生成器高级用法解析
2016/06/28 Python
Django中url的反向查询的方法
2018/03/14 Python
使用python itchat包爬取微信好友头像形成矩形头像集的方法
2019/02/21 Python
详解python多线程之间的同步(一)
2019/04/03 Python
Python使用微信接入图灵机器人过程解析
2019/11/04 Python
django框架中间件原理与用法详解
2019/12/10 Python
Python爬虫之Selenium警告框(弹窗)处理
2020/12/04 Python
Unix如何添加新的用户
2014/08/20 面试题
护士毕业自我鉴定
2014/02/07 职场文书
环保建议书300字
2014/05/14 职场文书
基层党员公开承诺书
2014/05/29 职场文书
对照检查剖析材料
2014/09/30 职场文书
党员批评与自我批评
2014/10/15 职场文书
2014年物业管理工作总结
2014/11/21 职场文书
财政局个人年终总结
2015/03/03 职场文书
2016年大学生寒假社会实践心得体会
2015/10/09 职场文书
大学生心理健康教育心得体会
2016/01/12 职场文书
诉讼和解协议书
2016/03/23 职场文书