tensorflow训练中出现nan问题的解决


Posted in Python onFebruary 10, 2018

深度学习中对于网络的训练是参数更新的过程,需要注意一种情况就是输入数据未做归一化时,如果前向传播结果已经是[0,0,0,1,0,0,0,0]这种形式,而真实结果是[1,0,0,0,0,0,0,0,0],此时由于得出的结论不惧有概率性,而是错误的估计值,此时反向传播会使得权重和偏置值变的无穷大,导致数据溢出,也就出现了nan的问题。

解决办法:

1、对输入数据进行归一化处理,如将输入的图片数据除以255将其转化成0-1之间的数据;

2、对于层数较多的情况,各层都做batch_nomorlization;

3、对设置Weights权重使用tf.truncated_normal(0, 0.01, [3,3,1,64])生成,同时值的均值为0,方差要小一些;

4、激活函数可以使用tanh;

5、减小学习率lr。

实例:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('data',one_hot = True)

def add_layer(input_data,in_size, out_size,activation_function=None):
  Weights = tf.Variable(tf.random_normal([in_size,out_size]))
  Biases = tf.Variable(tf.zeros([1, out_size])+0.1)
  Wx_plus_b = tf.add(tf.matmul(input_data, Weights), Biases)
  if activation_function==None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  #return outputs#, Weights
  return {'outdata':outputs, 'w':Weights}

def get_accuracy(t_y):
#  global l1
#  accu = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(l1['outdata'],1),tf.argmax(t_y,1)), dtype = tf.float32))
  global prediction
  accu = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(prediction['outdata'],1),tf.argmax(t_y,1)), dtype = tf.float32))
  return accu

X = tf.placeholder(tf.float32, [None, 784])
Y = tf.placeholder(tf.float32, [None, 10])

#l1 = add_layer(X, 784, 10, tf.nn.softmax)
#cross_entropy = tf.reduce_mean(-tf.reduce_sum(Y*tf.log(l1['outdata']), reduction_indices= [1]))
#l1 = add_layer(X, 784, 1024, tf.nn.relu)

l1 = add_layer(X, 784, 1024, None)
prediction = add_layer(l1['outdata'], 1024, 10, tf.nn.softmax)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(Y*tf.log(prediction['outdata']), reduction_indices= [1]))

optimizer = tf.train.GradientDescentOptimizer(0.000001)
train = optimizer.minimize(cross_entropy)


newW = tf.Variable(tf.random_normal([1024,10]))
newOut = tf.matmul(l1['outdata'],newW)
newSoftMax = tf.nn.softmax(newOut)

init = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init)
  #print(sess.run(l1_Weights))
  for i in range(2):
    X_train, y_train = mnist.train.next_batch(1)
    X_train = X_train/255  #需要进行归一化处理
    #print(sess.run(l1['w'],feed_dict={X:X_train}))
    #print(sess.run(prediction['w'],feed_dict={X:X_train, Y:y_train}))
    #print(sess.run(l1['outdata'],feed_dict={X:X_train, Y:y_train}).shape)
    print(sess.run(prediction['outdata'],feed_dict={X:X_train, Y:y_train}))
    print(sess.run(newOut, feed_dict={X:X_train}))
    print(sess.run(newSoftMax, feed_dict={X:X_train}))
    print(y_train)
    #print(sess.run(l1['outdata'], feed_dict={X:X_train}))
    sess.run(train, feed_dict={X:X_train, Y:y_train})
    if i%100 == 0:
      #print(sess.run(cross_entropy, feed_dict={X:X_train, Y:y_train}))
      accuracy = get_accuracy(mnist.test.labels)
      print(sess.run(accuracy,feed_dict={X:mnist.test.images}))
    
    #if i%100==0:
    #print(sess.run(prediction, feed_dict={X:X_train}))
    #print(sess.run(cross_entropy, feed_dict={X:X_train,Y:y_train}))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
深入浅析Python中join 和 split详解(推荐)
Jun 30 Python
使用Python操作excel文件的实例代码
Oct 15 Python
pandas 快速处理 date_time 日期格式方法
Nov 12 Python
Django Rest framework之认证的实现代码
Dec 17 Python
详解Python是如何实现issubclass的
Jul 24 Python
Pytorch evaluation每次运行结果不同的解决
Jan 02 Python
Python for循环通过序列索引迭代过程解析
Feb 07 Python
详解Tensorflow不同版本要求与CUDA及CUDNN版本对应关系
Aug 04 Python
理解Django 中Call Stack机制的小Demo
Sep 01 Python
属性与 @property 方法让你的python更高效
Sep 21 Python
Python开发.exe小工具的详细步骤
Jan 27 Python
selenium.webdriver中add_argument方法常用参数表
Apr 08 Python
用Eclipse写python程序
Feb 10 #Python
tensorflow建立一个简单的神经网络的方法
Feb 10 #Python
python取代netcat过程分析
Feb 10 #Python
浅谈Python黑帽子取代netcat
Feb 10 #Python
python3爬取淘宝信息代码分析
Feb 10 #Python
Python中property属性实例解析
Feb 10 #Python
Java编程迭代地删除文件夹及其下的所有文件实例
Feb 10 #Python
You might like
常用星际术语索引(新手指南)
2020/03/04 星际争霸
采用PHP函数memory_get_usage获取PHP内存清耗量的方法
2011/12/06 PHP
php打开文件fopen函数的使用说明
2013/07/05 PHP
php将csv文件导入到mysql数据库的方法
2014/12/24 PHP
php制作简单模版引擎
2016/04/07 PHP
Laravel框架实现抢红包功能示例
2019/10/31 PHP
javascript提取URL的搜索字符串中的参数(自定义函数实现)
2013/01/22 Javascript
Jquery 改变radio/checkbox选中状态,获取选中的值(示例代码)
2013/12/12 Javascript
node.js中的emitter.on方法使用说明
2014/12/10 Javascript
jquery中checkbox全选失效的解决方法
2014/12/26 Javascript
js动态修改表格行colspan列跨度的方法
2015/03/30 Javascript
AngularJS ng-bind-template 指令详解
2016/07/30 Javascript
H5手机端多文件上传预览插件
2017/04/21 Javascript
详解JavaScript基础知识(JSON、Function对象、原型、引用类型)
2018/01/16 Javascript
详解使用Next.js构建服务端渲染应用
2018/07/10 Javascript
JavaScript引用类型RegExp基本用法详解
2018/08/09 Javascript
详解angular2.x创建项目入门指令
2018/10/11 Javascript
深入理解vue中的slot与slot-scope
2019/04/22 Javascript
Angular8路由守卫原理和使用方法
2019/08/29 Javascript
Vue数字输入框组件使用方法详解
2020/02/10 Javascript
Python入门篇之条件、循环
2014/10/17 Python
Python下Fabric的简单部署方法
2015/07/14 Python
python获取命令行输入参数列表的实例代码
2018/06/23 Python
python定向爬虫校园论坛帖子信息
2018/07/23 Python
python中with用法讲解
2020/02/07 Python
python绘制动态曲线教程
2020/02/24 Python
pycharm 代码自动补全的实现方法(图文)
2020/09/18 Python
python爬虫使用scrapy注意事项
2020/11/23 Python
HTML5实现一个能够移动的小坦克示例代码
2013/09/02 HTML / CSS
美国在线宠物用品商店:Entirely Pets
2017/01/01 全球购物
乔迁宴答谢词
2014/01/21 职场文书
汽车队司机先进事迹材料
2014/02/01 职场文书
党的群众路线教育实践活动通讯稿
2014/09/10 职场文书
幼儿园卫生保健制度
2015/08/05 职场文书
2016年村党支部公开承诺书
2016/03/24 职场文书
css height属性中的calc方法详解
2021/06/03 HTML / CSS