Tensorflow中的dropout的使用方法


Posted in Python onMarch 13, 2020

Hinton在论文《Improving neural networks by preventing co-adaptation of feature detectors》中提出了Dropout。Dropout用来防止神经网络的过拟合。Tensorflow中可以通过如下3中方式实现dropout。

tf.nn.dropout

def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):

其中,x为浮点类型的tensor,keep_prob为浮点类型的scalar,范围在(0,1]之间,表示x中的元素被保留下来的概率,noise_shape为一维的tensor(int32类型),表示标记张量的形状(representing the shape for randomly generated keep/drop flags),并且noise_shape指定的形状必须对x的形状是可广播的。如果x的形状是[k, l, m, n],并且noise_shape为[k, l, m, n],那么x中的每一个元素是否保留都是独立,但如果x的形状是[k, l, m, n],并且noise_shape为[k, 1, 1, n],则x中的元素沿着第0个维度第3个维度以相互独立的概率保留或者丢弃,而元素沿着第1个维度和第2个维度要么同时保留,要么同时丢弃。

关于Tensorflow中的广播机制,可以参考《TensorFlow 和 NumPy 的 Broadcasting 机制探秘》

最终,会输出一个与x形状相同的张量ret,如果x中的元素被丢弃,则在ret中的对应位置元素为0,如果x中的元素被保留,则在ret中对应位置上的值为Tensorflow中的dropout的使用方法,这么做是为了使得ret中的元素之和等于x中的元素之和。

tf.layers.dropout

def dropout(inputs,
   rate=0.5,
   noise_shape=None,
   seed=None,
   training=False,
   name=None):

参数inputs为输入的张量,与tf.nn.dropout的参数keep_prob不同,rate指定元素被丢弃的概率,如果rate=0.1,则inputs中10%的元素将被丢弃,noise_shape与tf.nn.dropout的noise_shape一致,training参数用来指示当前阶段是出于训练阶段还是测试阶段,如果training为true(即训练阶段),则会进行dropout,否则不进行dropout,直接返回inputs。

自定义稀疏张量的dropout

上述的两种方法都是针对dense tensor的dropout,但有的时候,输入可能是稀疏张量,仿照tf.nn.dropout和tf.layers.dropout的内部实现原理,自定义稀疏张量的dropout。

def sparse_dropout(x, keep_prob, noise_shape):
 keep_tensor = keep_prob + tf.random_uniform(noise_shape)
 drop_mask = tf.cast(tf.floor(keep_tensor), dtype=tf.bool)
 out = tf.sparse_retain(x, drop_mask)
 return out * (1.0/keep_prob)

其中,参数x和keep_prob与tf.nn.dropout一致,noise_shape为x中非空元素的个数,如果x中有4个非空值,则noise_shape为[4],keep_tensor的元素为[keep_prob, 1.0 + keep_prob)的均匀分布,通过tf.floor向下取整得到标记张量drop_mask,tf.sparse_retain用于在一个 SparseTensor 中保留指定的非空值。

案例

def nn_dropout(x, keep_prob, noise_shape):
 out = tf.nn.dropout(x, keep_prob, noise_shape)
 return out


def layers_dropout(x, keep_prob, noise_shape, training=False):
 out = tf.layers.dropout(x, keep_prob, noise_shape, training=training)
 return out


def sparse_dropout(x, keep_prob, noise_shape):
 keep_tensor = keep_prob + tf.random_uniform(noise_shape)
 drop_mask = tf.cast(tf.floor(keep_tensor), dtype=tf.bool)
 out = tf.sparse_retain(x, drop_mask)
 return out * (1.0/keep_prob)


if __name__ == '__main__':
 inputs1 = tf.SparseTensor(indices=[[0, 0], [0, 2], [1, 1], [1, 2]], values=[1.0, 2.0, 3.0, 4.0], dense_shape=[2, 3])
 inputs2 = tf.sparse_tensor_to_dense(inputs1)
 nn_d_out = nn_dropout(inputs2, 0.5, [2, 3])
 layers_d_out = layers_dropout(inputs2, 0.5, [2, 3], training=True)
 sparse_d_out = sparse_dropout(inputs1, 0.5, [4])
 with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  (in1, in2) = sess.run([inputs1, inputs2])
  print(in1)
  print(in2)
  (out1, out2, out3) = sess.run([nn_d_out, layers_d_out, sparse_d_out])
  print(out1)
  print(out2)
  print(out3)

tensorflow中,稀疏张量为SparseTensor,稀疏张量的值为SparseTensorValue。3种dropout的输出如下,

SparseTensorValue(indices=array([[0, 0],
  [0, 2],
  [1, 1],
  [1, 2]], dtype=int64), values=array([ 1., 2., 3., 4.], dtype=float32), dense_shape=array([2, 3], dtype=int64))
[[ 1. 0. 2.]
 [ 0. 3. 4.]]
 
[[ 2. 0. 0.]
 [ 0. 0. 0.]]
[[ 0. 0. 4.]
 [ 0. 0. 0.]]
SparseTensorValue(indices=array([], shape=(0, 2), dtype=int64), values=array([], dtype=float32), dense_shape=array([2, 3], dtype=int64))

到此这篇关于Tensorflow中的dropout的使用方法的文章就介绍到这了,更多相关Tensorflow dropout内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python动态监控日志内容的示例
Feb 16 Python
python教程之用py2exe将PY文件转成EXE文件
Jun 12 Python
在Python中使用__slots__方法的详细教程
Apr 28 Python
Python编程中字符串和列表的基本知识讲解
Oct 14 Python
Python使用sorted排序的方法小结
Jul 28 Python
在Django中输出matplotlib生成的图片方法
May 24 Python
解决Django migrate No changes detected 不能创建表的问题
May 27 Python
Django ORM 常用字段与不常用字段汇总
Aug 09 Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 Python
python快速安装OpenCV的步骤记录
Feb 22 Python
python 如何用terminal输入参数
May 25 Python
解决pycharm下载库时出现Failed to install package的问题
Sep 04 Python
python实现简单俄罗斯方块
Mar 13 #Python
Python实现检测文件的MD5值来查找重复文件案例
Mar 12 #Python
python 判断txt每行内容中是否包含子串并重新写入保存的实例
Mar 12 #Python
python 两个一样的字符串用==结果为false问题的解决
Mar 12 #Python
python不相等的两个字符串的 if 条件判断为True详解
Mar 12 #Python
Python 实现使用空值进行赋值 None
Mar 12 #Python
PyCharm永久激活方式(推荐)
Sep 22 #Python
You might like
php获得文件扩展名三法
2006/11/25 PHP
PHP删除数组中空值的方法介绍
2014/04/14 PHP
ThinkPHP公共配置文件与各自项目中配置文件组合的方法
2014/11/24 PHP
PHP实现的简单网络硬盘
2015/07/29 PHP
PHP实现多维数组转字符串和多维数组转一维数组的方法
2015/08/08 PHP
PHP使用mkdir创建多级目录的方法
2015/12/22 PHP
基于yaf框架和uploadify插件,做的一个导入excel文件,查看并保存数据的功能
2017/01/24 PHP
xtree.js 代码
2007/03/13 Javascript
抽出www.templatemonster.com的鼠标悬停加载大图模板的代码
2007/07/11 Javascript
Extjs 几个方法的讨论
2010/01/28 Javascript
javascript实现的像java、c#之类的sleep暂停的函数代码
2010/03/04 Javascript
浅谈Javascript嵌套函数及闭包
2010/11/09 Javascript
JavaScript入门之对象与JSON详解
2011/10/21 Javascript
jQuery选择器源码解读(三):tokenize方法
2015/03/31 Javascript
从重置input file标签中看jQuery的 .val() 和 .attr(“value”) 区别
2016/06/12 Javascript
jQuery延迟执行的实现方法
2016/12/21 Javascript
使用Bootstrap4 + Vue2实现分页查询的示例代码
2017/12/21 Javascript
Angular实现下拉框模糊查询功能示例
2018/01/03 Javascript
JS实现的DOM插入节点操作示例
2018/04/04 Javascript
vue采用EventBus实现跨组件通信及注意事项小结
2018/06/14 Javascript
监听angularJs列表数据是否渲染完毕的方法示例
2018/11/07 Javascript
小程序实现抽奖动画
2020/04/16 Javascript
Vue使用axios出现options请求方法
2019/05/30 Javascript
Taro UI框架开发小程序实现左滑喜欢右滑不喜欢效果的示例代码
2020/05/18 Javascript
Python实现的下载8000首儿歌的代码分享
2014/11/21 Python
Django数据库连接丢失问题的解决方法
2018/12/29 Python
python使用正则来处理各种匹配问题
2019/12/22 Python
基于python实现语音录入识别代码实例
2020/01/17 Python
英国专业美容产品在线:Mylee(从指甲到脱毛)
2020/07/06 全球购物
《一件运动衫》教学反思
2014/02/19 职场文书
精神文明建设先进工作者事迹材料
2014/05/02 职场文书
幼儿园教师师德师风演讲稿:爱我所爱 无悔青春
2014/09/10 职场文书
会计人员演讲稿
2014/09/11 职场文书
人事任命通知
2015/04/20 职场文书
Java中CyclicBarrier和CountDownLatch的用法与区别
2021/08/23 Java/Android
在虚拟机中安装windows server 2008的图文教程
2022/06/28 Servers