浅谈keras中Dropout在预测过程中是否仍要起作用


Posted in Python onJuly 09, 2020

因为需要,要重写训练好的keras模型,虽然只具备预测功能,但是发现还是有很多坑要趟过。其中Dropout这个坑,我记忆犹新。

一开始,我以为预测时要保持和训练时完全一样的网络结构,也就是预测时用的网络也是有丢弃的网络节点,但是这样想就掉进了一个大坑!因为无法通过已经训练好的模型,来获取其训练时随机丢弃的网络节点是那些,这本身就根本不可能。

更重要的是:我发现每一个迭代周期丢弃的神经元也不完全一样。

假若迭代500次,网络共有1000个神经元, 在第n(1<= n <500)个迭代周期内,从1000个神经元里随机丢弃了200个神经元,在n+1个迭代周期内,会在这1000个神经元里(不是在剩余得800个)重新随机丢弃200个神经元。

训练过程中,使用Dropout,其实就是对部分权重和偏置在某次迭代训练过程中,不参与计算和更新而已,并不是不再使用这些权重和偏置了(预测时,会使用全部的神经元,包括使用训练时丢弃的神经元)。

也就是说在预测过程中完全没有Dropout什么事了,他只是在训练时有用,特别是针对训练集比较小时防止过拟合非常有用。

补充知识:TensorFlow直接使用ckpt模型predict不用restore

我就废话不多说了,大家还是直接看代码吧~

# -*- coding: utf-8 -*-
# from util import *
import cv2
import numpy as np
import tensorflow as tf
# from tensorflow.python.framework import graph_util
import os

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
image_path = './8760.pgm'

input_checkpoint = './model/xu_spatial_model_1340.ckpt'

sess = tf.Session()
saver = tf.train.import_meta_graph(input_checkpoint + '.meta')
saver.restore(sess, input_checkpoint)

# input:0作为输入图像,keep_prob:0作为dropout的参数,测试时值为1,is_training:0训练参数
input_image_tensor = sess.graph.get_tensor_by_name("coef_input:0")
is_training = sess.graph.get_tensor_by_name('is_training:0')
batch_size = sess.graph.get_tensor_by_name('batch_size:0')
# 定义输出的张量名称
output_tensor_name = sess.graph.get_tensor_by_name("xuNet/logits:0") # xuNet/Logits/logits
image = cv2.imread(image_path, 0)
# 读取测试图片
out = sess.run(output_tensor_name, feed_dict={input_image_tensor: np.reshape(image, (1, 512, 512, 1)),
                       is_training: False,
                       batch_size: 1})
print(out)

ckpt模型中的所有节点名称,可以这样查看

[n.name for n in tf.get_default_graph().as_graph_def().node]

以上这篇浅谈keras中Dropout在预测过程中是否仍要起作用就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python目录操作之python遍历文件夹后将结果存储为xml
Jan 27 Python
Python深入学习之内存管理
Aug 31 Python
python验证码识别的示例代码
Sep 21 Python
python3使用requests模块爬取页面内容的实战演练
Sep 25 Python
Django添加KindEditor富文本编辑器的使用
Oct 24 Python
Python数据分析模块pandas用法详解
Sep 04 Python
Pycharm+django2.2+python3.6+MySQL实现简单的考试报名系统
Sep 05 Python
python 用 xlwings 库 生成图表的操作方法
Dec 22 Python
浅谈python累加求和+奇偶数求和_break_continue
Feb 25 Python
python GUI库图形界面开发之PyQt5信号与槽的高级使用技巧(自定义信号与槽)详解与实例
Mar 06 Python
Opencv求取连通区域重心实例
Jun 04 Python
Python 键盘事件详解
Nov 11 Python
在keras中对单一输入图像进行预测并返回预测结果操作
Jul 09 #Python
python求解汉诺塔游戏
Jul 09 #Python
Django中Aggregation聚合的基本使用方法
Jul 09 #Python
Python  word实现读取及导出代码解析
Jul 09 #Python
推荐技术人员一款Python开源库(造数据神器)
Jul 08 #Python
实例讲解Python 迭代器与生成器
Jul 08 #Python
opencv 阈值分割的具体使用
Jul 08 #Python
You might like
使用php将某个目录下面的所有文件罗列出来的方法详解
2013/06/21 PHP
php自定义urlencode,urldecode函数实例
2015/03/24 PHP
php微信开发之百度天气预报
2016/11/18 PHP
thinkPHP自动验证机制详解
2016/12/05 PHP
如何通过View::first使用Laravel Blade的动态模板详解
2017/09/21 PHP
修复IE9&amp;safari 的sort方法
2011/10/21 Javascript
JS代码放在head和body中的区别分析
2011/12/01 Javascript
jquery仿京东导航/仿淘宝商城左侧分类导航下拉菜单效果
2013/04/24 Javascript
jquery 实现窗口的最大化不论什么情况
2013/09/03 Javascript
js弹出确认是否删除对话框
2014/03/27 Javascript
jQuery实现图片渐入渐出切换展示效果
2015/08/15 Javascript
Bootstrap实现带动画过渡的弹出框
2016/08/09 Javascript
Bootstrap table 定制提示语的加载过程
2017/02/20 Javascript
node.js利用redis数据库缓存数据的方法
2017/03/01 Javascript
JavaScript 保护变量不被随意修改的实现代码
2017/09/27 Javascript
JS实现键值对遍历json数组功能示例
2018/05/30 Javascript
Vue.js中的高级面试题及答案
2020/01/13 Javascript
JS数组Reduce方法功能与用法实例详解
2020/04/29 Javascript
nuxt 服务器渲染动态设置 title和seo关键字的操作
2020/11/05 Javascript
在Python的Flask框架下收发电子邮件的教程
2015/04/21 Python
Python守护进程用法实例分析
2015/06/04 Python
动感网页相册 python编写简单文件夹内图片浏览工具
2016/08/17 Python
python如何将图片转换为字符图片
2020/08/19 Python
Python实现账号密码输错三次即锁定功能简单示例
2019/03/29 Python
Python中typing模块与类型注解的使用方法
2019/08/05 Python
pandas 对日期类型数据的处理方法详解
2019/08/08 Python
Django对接支付宝实现支付宝充值金币功能示例
2019/12/17 Python
英国家用电器购物网站:Hughes
2018/02/23 全球购物
DOM和JQuery对象有什么区别
2016/11/11 面试题
药品质量检测应届生求职信
2013/11/14 职场文书
平面设计求职信
2014/03/10 职场文书
学生不讲诚信检讨书
2014/09/29 职场文书
HTML基础-标签分类(闭合标签,空标签,块级元素,行内元素,行级块元素,可替换元素)
2021/03/31 HTML / CSS
Python基础学习之奇异的GUI对话框
2021/05/27 Python
如何在pycharm中快捷安装pip命令(如pygame)
2021/05/31 Python
go开发alertmanger实现钉钉报警
2021/07/16 Golang