关于tf.nn.dynamic_rnn返回值详解


Posted in Python onJanuary 20, 2020

函数原型

tf.nn.dynamic_rnn(
  cell,
  inputs,
  sequence_length=None,
  initial_state=None,
  dtype=None,
  parallel_iterations=None,
  swap_memory=False,
  time_major=False,
  scope=None
)

实例讲解:

import tensorflow as tf
import numpy as np
 
n_steps = 2
n_inputs = 3
n_neurons = 5
 
X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
basic_cell = tf.contrib.rnn.BasicRNNCell(num_units=n_neurons)
 
seq_length = tf.placeholder(tf.int32, [None])
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32,
                  sequence_length=seq_length)
 
init = tf.global_variables_initializer()
 
X_batch = np.array([
    # step 0   step 1
    [[0, 1, 2], [9, 8, 7]], # instance 1
    [[3, 4, 5], [0, 0, 0]], # instance 2 (padded with zero vectors)
    [[6, 7, 8], [6, 5, 4]], # instance 3
    [[9, 0, 1], [3, 2, 1]], # instance 4
  ])
seq_length_batch = np.array([2, 1, 2, 2])
 
with tf.Session() as sess:
  init.run()
  outputs_val, states_val = sess.run(
    [outputs, states], feed_dict={X: X_batch, seq_length: seq_length_batch})
  print("outputs_val.shape:", outputs_val.shape, "states_val.shape:", states_val.shape)
  print("outputs_val:", outputs_val, "states_val:", states_val)

log info:

outputs_val.shape: (4, 2, 5) states_val.shape: (4, 5)
outputs_val: 
[[[ 0.53073734 -0.61281306 -0.5437517  0.7320347 -0.6109526 ]
 [ 0.99996936 0.99990636 -0.9867181  0.99726075 -0.99999976]]
 
 [[ 0.9931584  0.5877845 -0.9100412  0.988892  -0.9982337 ]
 [ 0.     0.     0.     0.     0.    ]]
 
 [[ 0.99992317 0.96815354 -0.985101  0.9995968 -0.9999936 ]
 [ 0.99948144 0.9998127 -0.57493806 0.91015154 -0.99998355]]
 
 [[ 0.99999255 0.9998929  0.26732785 0.36024097 -0.99991137]
 [ 0.98875254 0.9922327  0.6505734  0.4732064 -0.9957567 ]]] 
states_val:
 [[ 0.99996936 0.99990636 -0.9867181  0.99726075 -0.99999976]
 [ 0.9931584  0.5877845 -0.9100412  0.988892  -0.9982337 ]
 [ 0.99948144 0.9998127 -0.57493806 0.91015154 -0.99998355]
 [ 0.98875254 0.9922327  0.6505734  0.4732064 -0.9957567 ]]

首先输入X是一个 [batch_size,step,input_size] = [4,2,3] 的tensor,注意我们这里调用的是BasicRNNCell,只有一层循环网络,outputs是最后一层每个step的输出,它的结构是[batch_size,step,n_neurons] = [4,2,5],states是每一层的最后那个step的输出,由于本例中,我们的循环网络只有一个隐藏层,所以它就代表这一层的最后那个step的输出,因此它和step的大小是没有关系的,我们的X有4个样本组成,输出神经元大小n_neurons是5,因此states的结构就是[batch_size,n_neurons] = [4,5],最后我们观察数据,states的每条数据正好就是outputs的最后一个step的输出。

下面我们继续讲解多个隐藏层的情况,这里是三个隐藏层,注意我们这里仍然是调用BasicRNNCell

import tensorflow as tf
import numpy as np
 
n_steps = 2
n_inputs = 3
n_neurons = 5
n_layers = 3
 
X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
seq_length = tf.placeholder(tf.int32, [None])
 
layers = [tf.contrib.rnn.BasicRNNCell(num_units=n_neurons,
                   activation=tf.nn.relu)
     for layer in range(n_layers)]
multi_layer_cell = tf.contrib.rnn.MultiRNNCell(layers)
outputs, states = tf.nn.dynamic_rnn(multi_layer_cell, X, dtype=tf.float32, sequence_length=seq_length)
 
init = tf.global_variables_initializer()
 
X_batch = np.array([
    # step 0   step 1
    [[0, 1, 2], [9, 8, 7]], # instance 1
    [[3, 4, 5], [0, 0, 0]], # instance 2 (padded with zero vectors)
    [[6, 7, 8], [6, 5, 4]], # instance 3
    [[9, 0, 1], [3, 2, 1]], # instance 4
  ])
 
seq_length_batch = np.array([2, 1, 2, 2])
 
with tf.Session() as sess:
  init.run()
  outputs_val, states_val = sess.run(
    [outputs, states], feed_dict={X: X_batch, seq_length: seq_length_batch})
  print("outputs_val.shape:", outputs, "states_val.shape:", states)
  print("outputs_val:", outputs_val, "states_val:", states_val)

log info:

outputs_val.shape: 
Tensor("rnn/transpose_1:0", shape=(?, 2, 5), dtype=float32) 
 
states_val.shape: 
(<tf.Tensor 'rnn/while/Exit_3:0' shape=(?, 5) dtype=float32>, 
 <tf.Tensor 'rnn/while/Exit_4:0' shape=(?, 5) dtype=float32>, 
 <tf.Tensor 'rnn/while/Exit_5:0' shape=(?, 5) dtype=float32>)
 
outputs_val:
 [[[0.     0.     0.     0.     0.    ]
 [0.     0.18740742 0.     0.2997518 0.    ]]
 
 [[0.     0.07222144 0.     0.11551574 0.    ]
 [0.     0.     0.     0.     0.    ]]
 
 [[0.     0.13463384 0.     0.21534224 0.    ]
 [0.03702604 0.18443246 0.     0.34539366 0.    ]]
 
 [[0.     0.54511094 0.     0.8718864 0.    ]
 [0.5382122 0.     0.04396425 0.4040263 0.    ]]] 
 
states_val:
 (array([[0.    , 0.83723307, 0.    , 0.    , 2.8518028 ],
    [0.    , 0.1996038 , 0.    , 0.    , 1.5456247 ],
    [0.    , 1.1372368 , 0.    , 0.    , 0.832613 ],
    [0.    , 0.7904129 , 2.4675028 , 0.    , 0.36980057]],
   dtype=float32), 
 array([[0.6524607 , 0.    , 0.    , 0.    , 0.    ],
    [0.25143963, 0.    , 0.    , 0.    , 0.    ],
    [0.5010576 , 0.    , 0.    , 0.    , 0.    ],
    [0.    , 0.3166597 , 0.4545995 , 0.    , 0.    ]],
   dtype=float32), 
 array([[0.    , 0.18740742, 0.    , 0.2997518 , 0.    ],
    [0.    , 0.07222144, 0.    , 0.11551574, 0.    ],
    [0.03702604, 0.18443246, 0.    , 0.34539366, 0.    ],
    [0.5382122 , 0.    , 0.04396425, 0.4040263 , 0.    ]],
   dtype=float32))

我们说过,outputs是最后一层的输出,即 [batch_size,step,n_neurons] = [4,2,5]

states是每一层的最后一个step的输出,即三个结构为 [batch_size,n_neurons] = [4,5] 的tensor

继续观察数据,states中的最后一个array,正好是outputs的最后那个step的输出

下面我们继续讲当由BasicLSTMCell构造单元工厂的时候,只讲多层的情况,我们只需要将上面的BasicRNNCell替换成BasicLSTMCell就行了,打印信息如下:

outputs_val.shape: 
Tensor("rnn/transpose_1:0", shape=(?, 2, 5), dtype=float32) 
 
states_val.shape:
(LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_3:0' shape=(?, 5) dtype=float32>, 
        h=<tf.Tensor 'rnn/while/Exit_4:0' shape=(?, 5) dtype=float32>), 
LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_5:0' shape=(?, 5) dtype=float32>, 
        h=<tf.Tensor 'rnn/while/Exit_6:0' shape=(?, 5) dtype=float32>), 
LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_7:0' shape=(?, 5) dtype=float32>, 
        h=<tf.Tensor 'rnn/while/Exit_8:0' shape=(?, 5) dtype=float32>))
 
outputs_val: 
[[[1.2949290e-04 0.0000000e+00 2.7623639e-04 0.0000000e+00 0.0000000e+00]
 [9.4675866e-05 0.0000000e+00 2.0214770e-04 0.0000000e+00 0.0000000e+00]]
 
 [[4.3100454e-06 4.2123037e-07 1.4312843e-06 0.0000000e+00 0.0000000e+00]
 [0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]]
 
 [[0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
 [0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]]
 
 [[0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
 [0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]]] 
 
states_val: 
(LSTMStateTuple(
c=array([[0.    , 0.    , 0.04676079, 0.04284539, 0.    ],
    [0.    , 0.    , 0.0115245 , 0.    , 0.    ],
    [0.    , 0.    , 0.    , 0.    , 0.    ],
    [0.    , 0.    , 0.    , 0.    , 0.    ]],
   dtype=float32), 
h=array([[0.    , 0.    , 0.00035096, 0.04284406, 0.    ],
    [0.    , 0.    , 0.00142574, 0.    , 0.    ],
    [0.    , 0.    , 0.    , 0.    , 0.    ],
    [0.    , 0.    , 0.    , 0.    , 0.    ]],
   dtype=float32)), 
LSTMStateTuple(
c=array([[0.0000000e+00, 1.0477135e-02, 4.9871090e-03, 8.2785974e-04,
    0.0000000e+00],
    [0.0000000e+00, 2.3306280e-04, 0.0000000e+00, 9.9445322e-05,
    5.9535629e-05],
    [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
    0.0000000e+00],
    [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
    0.0000000e+00]], dtype=float32), 
h=array([[0.00000000e+00, 5.23016974e-03, 2.47756205e-03, 4.11730434e-04,
    0.00000000e+00],
    [0.00000000e+00, 1.16522635e-04, 0.00000000e+00, 4.97301044e-05,
    2.97713632e-05],
    [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
    [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00]], dtype=float32)), 
LSTMStateTuple(
c=array([[1.8937115e-04, 0.0000000e+00, 4.0442235e-04, 0.0000000e+00,
    0.0000000e+00],
    [8.6200516e-06, 8.4243663e-07, 2.8625946e-06, 0.0000000e+00,
    0.0000000e+00],
    [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
    0.0000000e+00],
    [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
    0.0000000e+00]], dtype=float32), 
h=array([[9.4675866e-05, 0.0000000e+00, 2.0214770e-04, 0.0000000e+00,
    0.0000000e+00],
    [4.3100454e-06, 4.2123037e-07, 1.4312843e-06, 0.0000000e+00,
    0.0000000e+00],
    [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
    0.0000000e+00],
    [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
    0.0000000e+00]], dtype=float32)))

我们先看看LSTM单元的结构

关于tf.nn.dynamic_rnn返回值详解

如果您不查看框内的内容,LSTM单元看起来与常规单元格完全相同,除了它的状态分为两个向量:h(t)和c(t)。你可以将h(t)视为短期状态,将c(t)视为长期状态。

因此我们的states包含三个LSTMStateTuple,每一个表示每一层的最后一个step的输出,这个输出有两个信息,一个是h表示短期记忆信息,一个是c表示长期记忆信息。维度都是[batch_size,n_neurons] = [4,5],states的最后一个LSTMStateTuple中的h就是outputs的最后一个step的输出

以上这篇关于tf.nn.dynamic_rnn返回值详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python strip lstrip rstrip使用方法
Sep 06 Python
Python Socket编程入门教程
Jul 11 Python
python中使用百度音乐搜索的api下载指定歌曲的lrc歌词
Jul 18 Python
Python实现将DOC文档转换为PDF的方法
Jul 25 Python
Python制作词云的方法
Jan 03 Python
1分钟快速生成用于网页内容提取的xslt
Feb 23 Python
python中的json总结
Oct 11 Python
Python3中在Anaconda环境下安装basemap包
Oct 21 Python
Python实现微信机器人的方法
Sep 06 Python
python3跳出一个循环的实例操作
Aug 18 Python
Pycharm github配置实现过程图解
Oct 13 Python
python3通过subprocess模块调用脚本并和脚本交互的操作
Dec 05 Python
双向RNN:bidirectional_dynamic_rnn()函数的使用详解
Jan 20 #Python
关于tf.reverse_sequence()简述
Jan 20 #Python
tensorflow使用range_input_producer多线程读取数据实例
Jan 20 #Python
浅谈tensorflow中Dataset图片的批量读取及维度的操作详解
Jan 20 #Python
使用tensorflow DataSet实现高效加载变长文本输入
Jan 20 #Python
python机器学习库xgboost的使用
Jan 20 #Python
python 爬取马蜂窝景点翻页文字评论的实现
Jan 20 #Python
You might like
php+js实现的拖动滑块验证码验证表单操作示例【附源码下载】
2020/05/27 PHP
JavaScript 原型继承
2011/12/26 Javascript
EXTjs4.0的store的findRecord的BUG演示代码
2013/06/08 Javascript
js实现俄罗斯方块小游戏分享
2014/01/31 Javascript
javascript日期格式化示例分享
2014/03/05 Javascript
JS 新增Cookie 取cookie值 删除cookie 举例详解
2014/10/10 Javascript
Javascript实现div的toggle效果实例分析
2015/06/09 Javascript
easyui tree带checkbox实现单选的简单实例
2016/11/07 Javascript
BootStrap实现轮播图效果(收藏)
2016/12/30 Javascript
js实现模糊匹配功能
2017/02/15 Javascript
AngularJS select设置默认值的实现方法
2017/08/25 Javascript
浅谈Vue的加载顺序探讨
2017/10/25 Javascript
layui框架中layer父子页面交互的方法分析
2017/11/15 Javascript
浅谈PDF.js使用心得
2018/06/07 Javascript
JS实现判断图片是否加载完成的方法分析
2018/07/31 Javascript
基于JS抓取某高校附近共享单车位置 使用web方式展示位置变化代码实例
2019/08/27 Javascript
javascript(基于jQuery)实现鼠标获取选中的文字示例【测试可用】
2019/10/26 jQuery
Vue关于组件化开发知识点详解
2020/05/13 Javascript
微信小程序实现弹框效果
2020/05/26 Javascript
Vue+Spring Boot简单用户登录(附Demo)
2020/11/12 Javascript
[01:06:25]Secret vs Liquid 2018国际邀请赛淘汰赛BO3 第一场 8.25
2018/08/29 DOTA
Python转换HTML到Text纯文本的方法
2015/01/15 Python
以一段代码为实例快速入门Python2.7
2015/03/31 Python
Python如何实现守护进程的方法示例
2017/02/08 Python
Python在图片中添加文字的两种方法
2017/04/29 Python
Python企业编码生成系统总体系统设计概述
2019/07/26 Python
python实现简单的五子棋游戏
2020/09/01 Python
Python爬取梨视频的示例
2021/01/29 Python
HTML5是否真的可以取代Flash
2010/02/10 HTML / CSS
HTML5移动端开发中的Viewport标签及相关CSS用法解析
2016/04/15 HTML / CSS
美国汽配连锁巨头Pep Boys官网:轮胎更换、汽车维修服务和汽车零部件
2017/01/14 全球购物
Monki官网:斯堪的纳维亚的独立时尚品牌
2020/11/09 全球购物
航空学院求职信
2014/06/11 职场文书
小学数学国培研修日志
2015/11/13 职场文书
2016党员读书思廉心得体会
2016/01/23 职场文书
导游词之吉林吉塔
2019/11/11 职场文书