解决TensorFlow调用Keras库函数存在的问题


Posted in Python onJuly 06, 2020

tensorflow在1.4版本引入了keras,封装成库。现想将keras版本的GRU代码移植到TensorFlow中,看到TensorFlow中有Keras库,大喜,故将神经网络定义部分使用Keras的Function API方式进行定义,训练部分则使用TensorFlow来进行编写。一顿操作之后,运行,没有报错,不由得一喜。但是输出结果,发现,和预期的不一样。难道是欠拟合?故采用正弦波预测余弦来验证算法模型。

部分调用keras库代码如上图所示,用正弦波预测余弦波,出现如下现象:

def interface(_input):
  tmp = tf.keras.layers.Dense(10)(_input)
  vad_gru = tf.keras.layers.GRU(24, return_sequences=True)(tmp)
  denoise_output = tf.keras.layers.Dense(1)(vad_gru)
  return denoise_output

波形是断断续续的。而且最后不收敛。

解决TensorFlow调用Keras库函数存在的问题

运行N久。。。之后

基本断定是程序本身的问题,于是通过排查,发现应该是GRU的initial_state没有进行更新导致的。导致波形是断断续续的,没有学习到前一次网络的输出。于是,决定不使用Keras库实现一遍:

部分代码如下:

def interface(_input):
  tmp = tf.keras.layers.Dense(10)(_input)
  gru_cell = tf.nn.rnn_cell.GRUCell(vad_cell_size)
  with tf.name_scope('initial_state'):
    cell_init_state = gru_cell.zero_state(batch_size, dtype=tf.float32)
  cell_outputs, cell_final_state = tf.nn.dynamic_rnn(
    gru_cell, tmp, initial_state=cell_init_state, time_major=False)
  denoise_output = tf.keras.layers.Dense(1)(cell_outputs)
  return denoise_output, cell_init_state, cell_final_state

波形图如下(这才是GRU的正确打开方式啊~):

解决TensorFlow调用Keras库函数存在的问题

再回头看之前写的调用keras,既然知道了是initial_state没有更新,那么如何进行更新呢?

网上查找了大量的资料,说要加上

update_ops = []
for old_value, new_value in layers.updates:
  update_ops.append(tf.assign(old_value, new_value))

但是加上去没有效果,是我加错了还是其他的,大家欢迎指出来

以下是我做的一些尝试,就不一一详细说明了,大家看一下,具体不再展开,有问题大家交流一下,有解决方法的,能够分享出来,感激不尽~

def interface(_input):
  # input_layer = tf.keras.layers.Input([None, 1])
  # input_layer = tf.keras.layers.Input(batch_shape=(50, 20, 1))
  tmp = tf.keras.layers.Dense(10)(_input)
  # tmp = tf.keras.layers.Dense(24)(tmp)
 
  # with tf.variable_scope('vad_gru', reuse=tf.AUTO_REUSE):
  # vad_gru, final_state = tf.keras.layers.GRU(24, return_sequences=True, return_state=True, stateful=True)(tmp)
  # print(vad_gru)
  # _initial_state = vad_gru.zero_state(50, tf.float32)
  # tf.get_variable_scope().reuse_variables()
 
  # vad_gru = tf.contrib.
 
  # tmp = tf.reshape(tmp, [-1, TIME_STEPS, vad_cell_size])
  gru_cell = tf.nn.rnn_cell.GRUCell(vad_cell_size)
  # gru_cell = tf.keras.layers.GRUCell(self.vad_cell_size)
  with tf.name_scope('initial_state'):
    cell_init_state = gru_cell.zero_state(batch_size, dtype=tf.float32)
  cell_outputs, cell_final_state = tf.nn.dynamic_rnn(
    gru_cell, tmp, initial_state=cell_init_state, time_major=False)
  # print(cell_outputs.get_shape().as_list())
 
  # cell_outputs = tf.reshape(cell_outputs, [-1, vad_cell_size])
 
  denoise_output = tf.keras.layers.Dense(1)(cell_outputs)
  print(denoise_output.get_shape().as_list())
 
  # model = tf.keras.models.Model(input_layer, denoise_output)
  # update_ops = []
  # for old_value, new_value in model.layers[1].updates:
  #   update_ops.append(tf.assign(old_value, new_value))
 
  return denoise_output, cell_init_state, cell_final_state

补充知识:TensorFlow和Keras常用方法(避坑)

TensorFlow

在TensorFlow中,除法运算:

1.tensor除法会使结果的精度高一级,可能会导致后面计算类型不匹配,如float32 / float32 = float64。

2.除法需要分子分母同类型,否则报错。

产生类似错误提示如下:

-1.TypeError: x and y must have the same dtype, got tf.float32 != tf.int32

-2.TypeError: Input ‘y' of ‘Mul' Op has type float32 that does not match type float64 of argument ‘x'.

-3.ValueError: Tensor conversion requested dtype float64 for Tensor with dtype float32: ‘Tensor(“Sum:0”, shape=(), dtype=float32)'

-4.ValueError: Incompatible type conversion requested to type ‘int32' for variable of type ‘float32_ref'

解决办法:

tf.cast(a, tf.float32) # 转换成同类型即可

tf.boolean_mask

K.gather

K.argmax

K.max

以上这篇解决TensorFlow调用Keras库函数存在的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
浅析Python中的多进程与多线程的使用
Apr 07 Python
python检查字符串是否是正确ISBN的方法
Jul 11 Python
PYTHON 中使用 GLOBAL引发的一系列问题
Oct 12 Python
对Python中range()函数和list的比较
Apr 19 Python
利用python如何处理nc数据详解
May 23 Python
对Python 除法负数取商的取整方式详解
Dec 12 Python
Python使用微信itchat接口实现查看自己微信的信息功能详解
Aug 22 Python
对tensorflow中的strides参数使用详解
Jan 04 Python
Python使用plt.boxplot() 参数绘制箱线图
Jun 04 Python
Python selenium爬虫实现定时任务过程解析
Jun 08 Python
用 python 进行微信好友信息分析
Nov 28 Python
Pygame Draw绘图函数的具体使用
Nov 17 Python
python else语句在循环中的运用详解
Jul 06 #Python
Keras模型转成tensorflow的.pb操作
Jul 06 #Python
python如何进入交互模式
Jul 06 #Python
python3.4中清屏的处理方法
Jul 06 #Python
Python3基于print打印带颜色字符串
Jul 06 #Python
python判断是空的实例分享
Jul 06 #Python
python三引号如何输入
Jul 06 #Python
You might like
PHP连接MySQL数据的操作要点
2015/03/20 PHP
详细解读PHP中接口的应用
2015/08/12 PHP
thinkPHP中volist标签用法示例
2016/12/06 PHP
php生成毫秒时间戳的实例讲解
2017/09/22 PHP
提高Laravel应用性能方法详解
2019/06/24 PHP
PHP pthreads v3下同步处理synchronized用法示例
2020/02/21 PHP
cssQuery()的下载与使用方法
2007/01/12 Javascript
lyhucSelect基于Jquery的Select数据联动插件
2011/03/29 Javascript
jQuery判断iframe中元素是否存在的方法
2013/05/11 Javascript
快速解决FusionCharts联动的中文乱码问题
2013/12/04 Javascript
jquery滚动特效集锦
2015/06/03 Javascript
跟我学习javascript的严格模式
2015/11/16 Javascript
jQuery获取复选框被选中数量及判断选择值的方法详解
2016/05/25 Javascript
JavaScript无阻塞加载和defer、async详解
2017/02/26 Javascript
js数字舍入误差以及解决方法(必看篇)
2017/02/28 Javascript
seajs中最常用的7个功能、配置示例
2017/10/10 Javascript
关于Vue的路由权限管理的示例代码
2018/03/06 Javascript
setTimeout时间设置为0详细解析
2018/03/13 Javascript
Vue 全家桶实现移动端酷狗音乐功能
2018/11/16 Javascript
Vue.js的动态组件模板的实现
2018/11/26 Javascript
PWA介绍及快速上手搭建一个PWA应用的方法
2019/01/27 Javascript
JS实现普通轮播图特效
2020/01/01 Javascript
[03:12]TI9战队档案 - Virtus Pro
2019/08/20 DOTA
python3简单实现微信爬虫
2015/04/09 Python
Python下使用Scrapy爬取网页内容的实例
2018/05/21 Python
Python用摘要算法生成token及检验token的示例代码
2020/12/01 Python
美国最大的家庭鞋类零售商之一:Shoe Carnival
2017/10/06 全球购物
Skyscanner英国:苏格兰的全球三大领先航班搜索服务之一
2017/11/09 全球购物
新西兰最大、占有率最高的综合性药房:PharmacyDirect药房中文网
2020/11/03 全球购物
short s1 = 1; s1 = s1 + 1;有什么错? short s1 = 1; s1 += 1;有什么错?
2014/09/26 面试题
乡镇党员干部群众路线对照检查材料思想汇报
2014/09/28 职场文书
省委召开党的群众路线教育实践活动总结大会报告
2014/10/21 职场文书
领导视察通讯稿
2015/07/18 职场文书
创业计划书之餐饮
2019/09/02 职场文书
解决redis批量删除key值的问题
2022/03/23 Redis
Python如何快速找到多个字典中的公共键(key)
2022/04/29 Python