关于tf.nn.dynamic_rnn返回值详解


Posted in Python onJanuary 20, 2020

函数原型

tf.nn.dynamic_rnn(
  cell,
  inputs,
  sequence_length=None,
  initial_state=None,
  dtype=None,
  parallel_iterations=None,
  swap_memory=False,
  time_major=False,
  scope=None
)

实例讲解:

import tensorflow as tf
import numpy as np
 
n_steps = 2
n_inputs = 3
n_neurons = 5
 
X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
basic_cell = tf.contrib.rnn.BasicRNNCell(num_units=n_neurons)
 
seq_length = tf.placeholder(tf.int32, [None])
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32,
                  sequence_length=seq_length)
 
init = tf.global_variables_initializer()
 
X_batch = np.array([
    # step 0   step 1
    [[0, 1, 2], [9, 8, 7]], # instance 1
    [[3, 4, 5], [0, 0, 0]], # instance 2 (padded with zero vectors)
    [[6, 7, 8], [6, 5, 4]], # instance 3
    [[9, 0, 1], [3, 2, 1]], # instance 4
  ])
seq_length_batch = np.array([2, 1, 2, 2])
 
with tf.Session() as sess:
  init.run()
  outputs_val, states_val = sess.run(
    [outputs, states], feed_dict={X: X_batch, seq_length: seq_length_batch})
  print("outputs_val.shape:", outputs_val.shape, "states_val.shape:", states_val.shape)
  print("outputs_val:", outputs_val, "states_val:", states_val)

log info:

outputs_val.shape: (4, 2, 5) states_val.shape: (4, 5)
outputs_val: 
[[[ 0.53073734 -0.61281306 -0.5437517  0.7320347 -0.6109526 ]
 [ 0.99996936 0.99990636 -0.9867181  0.99726075 -0.99999976]]
 
 [[ 0.9931584  0.5877845 -0.9100412  0.988892  -0.9982337 ]
 [ 0.     0.     0.     0.     0.    ]]
 
 [[ 0.99992317 0.96815354 -0.985101  0.9995968 -0.9999936 ]
 [ 0.99948144 0.9998127 -0.57493806 0.91015154 -0.99998355]]
 
 [[ 0.99999255 0.9998929  0.26732785 0.36024097 -0.99991137]
 [ 0.98875254 0.9922327  0.6505734  0.4732064 -0.9957567 ]]] 
states_val:
 [[ 0.99996936 0.99990636 -0.9867181  0.99726075 -0.99999976]
 [ 0.9931584  0.5877845 -0.9100412  0.988892  -0.9982337 ]
 [ 0.99948144 0.9998127 -0.57493806 0.91015154 -0.99998355]
 [ 0.98875254 0.9922327  0.6505734  0.4732064 -0.9957567 ]]

首先输入X是一个 [batch_size,step,input_size] = [4,2,3] 的tensor,注意我们这里调用的是BasicRNNCell,只有一层循环网络,outputs是最后一层每个step的输出,它的结构是[batch_size,step,n_neurons] = [4,2,5],states是每一层的最后那个step的输出,由于本例中,我们的循环网络只有一个隐藏层,所以它就代表这一层的最后那个step的输出,因此它和step的大小是没有关系的,我们的X有4个样本组成,输出神经元大小n_neurons是5,因此states的结构就是[batch_size,n_neurons] = [4,5],最后我们观察数据,states的每条数据正好就是outputs的最后一个step的输出。

下面我们继续讲解多个隐藏层的情况,这里是三个隐藏层,注意我们这里仍然是调用BasicRNNCell

import tensorflow as tf
import numpy as np
 
n_steps = 2
n_inputs = 3
n_neurons = 5
n_layers = 3
 
X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
seq_length = tf.placeholder(tf.int32, [None])
 
layers = [tf.contrib.rnn.BasicRNNCell(num_units=n_neurons,
                   activation=tf.nn.relu)
     for layer in range(n_layers)]
multi_layer_cell = tf.contrib.rnn.MultiRNNCell(layers)
outputs, states = tf.nn.dynamic_rnn(multi_layer_cell, X, dtype=tf.float32, sequence_length=seq_length)
 
init = tf.global_variables_initializer()
 
X_batch = np.array([
    # step 0   step 1
    [[0, 1, 2], [9, 8, 7]], # instance 1
    [[3, 4, 5], [0, 0, 0]], # instance 2 (padded with zero vectors)
    [[6, 7, 8], [6, 5, 4]], # instance 3
    [[9, 0, 1], [3, 2, 1]], # instance 4
  ])
 
seq_length_batch = np.array([2, 1, 2, 2])
 
with tf.Session() as sess:
  init.run()
  outputs_val, states_val = sess.run(
    [outputs, states], feed_dict={X: X_batch, seq_length: seq_length_batch})
  print("outputs_val.shape:", outputs, "states_val.shape:", states)
  print("outputs_val:", outputs_val, "states_val:", states_val)

log info:

outputs_val.shape: 
Tensor("rnn/transpose_1:0", shape=(?, 2, 5), dtype=float32) 
 
states_val.shape: 
(<tf.Tensor 'rnn/while/Exit_3:0' shape=(?, 5) dtype=float32>, 
 <tf.Tensor 'rnn/while/Exit_4:0' shape=(?, 5) dtype=float32>, 
 <tf.Tensor 'rnn/while/Exit_5:0' shape=(?, 5) dtype=float32>)
 
outputs_val:
 [[[0.     0.     0.     0.     0.    ]
 [0.     0.18740742 0.     0.2997518 0.    ]]
 
 [[0.     0.07222144 0.     0.11551574 0.    ]
 [0.     0.     0.     0.     0.    ]]
 
 [[0.     0.13463384 0.     0.21534224 0.    ]
 [0.03702604 0.18443246 0.     0.34539366 0.    ]]
 
 [[0.     0.54511094 0.     0.8718864 0.    ]
 [0.5382122 0.     0.04396425 0.4040263 0.    ]]] 
 
states_val:
 (array([[0.    , 0.83723307, 0.    , 0.    , 2.8518028 ],
    [0.    , 0.1996038 , 0.    , 0.    , 1.5456247 ],
    [0.    , 1.1372368 , 0.    , 0.    , 0.832613 ],
    [0.    , 0.7904129 , 2.4675028 , 0.    , 0.36980057]],
   dtype=float32), 
 array([[0.6524607 , 0.    , 0.    , 0.    , 0.    ],
    [0.25143963, 0.    , 0.    , 0.    , 0.    ],
    [0.5010576 , 0.    , 0.    , 0.    , 0.    ],
    [0.    , 0.3166597 , 0.4545995 , 0.    , 0.    ]],
   dtype=float32), 
 array([[0.    , 0.18740742, 0.    , 0.2997518 , 0.    ],
    [0.    , 0.07222144, 0.    , 0.11551574, 0.    ],
    [0.03702604, 0.18443246, 0.    , 0.34539366, 0.    ],
    [0.5382122 , 0.    , 0.04396425, 0.4040263 , 0.    ]],
   dtype=float32))

我们说过,outputs是最后一层的输出,即 [batch_size,step,n_neurons] = [4,2,5]

states是每一层的最后一个step的输出,即三个结构为 [batch_size,n_neurons] = [4,5] 的tensor

继续观察数据,states中的最后一个array,正好是outputs的最后那个step的输出

下面我们继续讲当由BasicLSTMCell构造单元工厂的时候,只讲多层的情况,我们只需要将上面的BasicRNNCell替换成BasicLSTMCell就行了,打印信息如下:

outputs_val.shape: 
Tensor("rnn/transpose_1:0", shape=(?, 2, 5), dtype=float32) 
 
states_val.shape:
(LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_3:0' shape=(?, 5) dtype=float32>, 
        h=<tf.Tensor 'rnn/while/Exit_4:0' shape=(?, 5) dtype=float32>), 
LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_5:0' shape=(?, 5) dtype=float32>, 
        h=<tf.Tensor 'rnn/while/Exit_6:0' shape=(?, 5) dtype=float32>), 
LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_7:0' shape=(?, 5) dtype=float32>, 
        h=<tf.Tensor 'rnn/while/Exit_8:0' shape=(?, 5) dtype=float32>))
 
outputs_val: 
[[[1.2949290e-04 0.0000000e+00 2.7623639e-04 0.0000000e+00 0.0000000e+00]
 [9.4675866e-05 0.0000000e+00 2.0214770e-04 0.0000000e+00 0.0000000e+00]]
 
 [[4.3100454e-06 4.2123037e-07 1.4312843e-06 0.0000000e+00 0.0000000e+00]
 [0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]]
 
 [[0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
 [0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]]
 
 [[0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
 [0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]]] 
 
states_val: 
(LSTMStateTuple(
c=array([[0.    , 0.    , 0.04676079, 0.04284539, 0.    ],
    [0.    , 0.    , 0.0115245 , 0.    , 0.    ],
    [0.    , 0.    , 0.    , 0.    , 0.    ],
    [0.    , 0.    , 0.    , 0.    , 0.    ]],
   dtype=float32), 
h=array([[0.    , 0.    , 0.00035096, 0.04284406, 0.    ],
    [0.    , 0.    , 0.00142574, 0.    , 0.    ],
    [0.    , 0.    , 0.    , 0.    , 0.    ],
    [0.    , 0.    , 0.    , 0.    , 0.    ]],
   dtype=float32)), 
LSTMStateTuple(
c=array([[0.0000000e+00, 1.0477135e-02, 4.9871090e-03, 8.2785974e-04,
    0.0000000e+00],
    [0.0000000e+00, 2.3306280e-04, 0.0000000e+00, 9.9445322e-05,
    5.9535629e-05],
    [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
    0.0000000e+00],
    [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
    0.0000000e+00]], dtype=float32), 
h=array([[0.00000000e+00, 5.23016974e-03, 2.47756205e-03, 4.11730434e-04,
    0.00000000e+00],
    [0.00000000e+00, 1.16522635e-04, 0.00000000e+00, 4.97301044e-05,
    2.97713632e-05],
    [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
    [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00]], dtype=float32)), 
LSTMStateTuple(
c=array([[1.8937115e-04, 0.0000000e+00, 4.0442235e-04, 0.0000000e+00,
    0.0000000e+00],
    [8.6200516e-06, 8.4243663e-07, 2.8625946e-06, 0.0000000e+00,
    0.0000000e+00],
    [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
    0.0000000e+00],
    [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
    0.0000000e+00]], dtype=float32), 
h=array([[9.4675866e-05, 0.0000000e+00, 2.0214770e-04, 0.0000000e+00,
    0.0000000e+00],
    [4.3100454e-06, 4.2123037e-07, 1.4312843e-06, 0.0000000e+00,
    0.0000000e+00],
    [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
    0.0000000e+00],
    [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
    0.0000000e+00]], dtype=float32)))

我们先看看LSTM单元的结构

关于tf.nn.dynamic_rnn返回值详解

如果您不查看框内的内容,LSTM单元看起来与常规单元格完全相同,除了它的状态分为两个向量:h(t)和c(t)。你可以将h(t)视为短期状态,将c(t)视为长期状态。

因此我们的states包含三个LSTMStateTuple,每一个表示每一层的最后一个step的输出,这个输出有两个信息,一个是h表示短期记忆信息,一个是c表示长期记忆信息。维度都是[batch_size,n_neurons] = [4,5],states的最后一个LSTMStateTuple中的h就是outputs的最后一个step的输出

以上这篇关于tf.nn.dynamic_rnn返回值详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用CMD模块更优雅的运行脚本
May 11 Python
python在windows下创建隐藏窗口子进程的方法
Jun 04 Python
python实现稀疏矩阵示例代码
Jun 09 Python
python实现感知器
Dec 19 Python
matplotlib作图添加表格实例代码
Jan 23 Python
python字典快速保存于读取的方法
Mar 23 Python
Python中垃圾回收和del语句详解
Nov 15 Python
python操作excel让工作自动化
Aug 09 Python
解决springboot yml配置 logging.level 报错问题
Feb 21 Python
Python 日期时间datetime 加一天,减一天,加减一小时一分钟,加减一年
Apr 16 Python
Python2.x与3​​.x版本有哪些区别
Jul 09 Python
Python在centos7.6上安装python3.9的详细教程(默认python版本为2.7.5)
Oct 15 Python
双向RNN:bidirectional_dynamic_rnn()函数的使用详解
Jan 20 #Python
关于tf.reverse_sequence()简述
Jan 20 #Python
tensorflow使用range_input_producer多线程读取数据实例
Jan 20 #Python
浅谈tensorflow中Dataset图片的批量读取及维度的操作详解
Jan 20 #Python
使用tensorflow DataSet实现高效加载变长文本输入
Jan 20 #Python
python机器学习库xgboost的使用
Jan 20 #Python
python 爬取马蜂窝景点翻页文字评论的实现
Jan 20 #Python
You might like
怎样辨别一杯好咖啡
2021/03/03 新手入门
第四节--构造函数和析构函数
2006/11/16 PHP
php数组操作之键名比较与差集、交集赋值的方法
2014/11/10 PHP
PHP结合Jquery和ajax实现瀑布流特效
2016/01/07 PHP
Smarty模板引擎缓存机制详解
2016/05/23 PHP
PHP怎样用正则抓取页面中的网址
2016/08/09 PHP
PHP+Ajax简单get验证操作示例
2019/03/02 PHP
Laravel 自带的Auth验证登录方法
2019/09/30 PHP
使用JS 清空File控件的路径值
2013/07/08 Javascript
ext前台接收action传过来的json数据示例
2014/06/17 Javascript
原生js和jQuery随意改变div属性style的名称和值
2014/10/22 Javascript
JS实现简单路由器功能的方法
2015/05/27 Javascript
Jquery中巧用Ajax的beforeSend方法
2016/01/20 Javascript
js a标签点击事件
2017/03/30 Javascript
react.js CMS 删除功能的实现方法
2017/04/17 Javascript
使用prop解决一个checkbox选中后再次选中失效的问题
2017/07/05 Javascript
微信小程序实现获取自己所处位置的经纬度坐标功能示例
2017/11/30 Javascript
vue2.0使用swiper组件实现轮播的示例代码
2018/03/03 Javascript
elementUI select组件使用及注意事项详解
2019/05/29 Javascript
使用axios请求接口,几种content-type的区别详解
2019/10/29 Javascript
JSON获取属性值方法代码实例
2020/06/30 Javascript
Vue实现boradcast和dispatch的示例
2020/11/13 Javascript
python用Pygal如何生成漂亮的SVG图像详解
2017/02/10 Python
python实战之实现excel读取、统计、写入的示例讲解
2018/05/02 Python
pandas表连接 索引上的合并方法
2018/06/08 Python
浅谈Python里面小数点精度的控制
2018/07/16 Python
Python实现的读取/更改/写入xml文件操作示例
2018/08/30 Python
对Python3.x版本print函数左右对齐详解
2018/12/22 Python
浅谈python常用程序算法
2019/03/22 Python
Python 3.8正式发布,来尝鲜这些新特性吧
2019/10/15 Python
使用CSS3实现一个3D相册效果实例
2016/12/03 HTML / CSS
会计电算化个人求职信范文
2014/01/24 职场文书
二年级语文下册复习计划
2015/01/19 职场文书
《草虫的村落》教学反思
2016/02/20 职场文书
CSS的class与id常用的命名规则
2021/05/18 HTML / CSS
vue如何在data中引入图片的正确路径
2022/06/05 Vue.js