tensorflow训练中出现nan问题的解决


Posted in Python onFebruary 10, 2018

深度学习中对于网络的训练是参数更新的过程,需要注意一种情况就是输入数据未做归一化时,如果前向传播结果已经是[0,0,0,1,0,0,0,0]这种形式,而真实结果是[1,0,0,0,0,0,0,0,0],此时由于得出的结论不惧有概率性,而是错误的估计值,此时反向传播会使得权重和偏置值变的无穷大,导致数据溢出,也就出现了nan的问题。

解决办法:

1、对输入数据进行归一化处理,如将输入的图片数据除以255将其转化成0-1之间的数据;

2、对于层数较多的情况,各层都做batch_nomorlization;

3、对设置Weights权重使用tf.truncated_normal(0, 0.01, [3,3,1,64])生成,同时值的均值为0,方差要小一些;

4、激活函数可以使用tanh;

5、减小学习率lr。

实例:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('data',one_hot = True)

def add_layer(input_data,in_size, out_size,activation_function=None):
  Weights = tf.Variable(tf.random_normal([in_size,out_size]))
  Biases = tf.Variable(tf.zeros([1, out_size])+0.1)
  Wx_plus_b = tf.add(tf.matmul(input_data, Weights), Biases)
  if activation_function==None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  #return outputs#, Weights
  return {'outdata':outputs, 'w':Weights}

def get_accuracy(t_y):
#  global l1
#  accu = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(l1['outdata'],1),tf.argmax(t_y,1)), dtype = tf.float32))
  global prediction
  accu = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(prediction['outdata'],1),tf.argmax(t_y,1)), dtype = tf.float32))
  return accu

X = tf.placeholder(tf.float32, [None, 784])
Y = tf.placeholder(tf.float32, [None, 10])

#l1 = add_layer(X, 784, 10, tf.nn.softmax)
#cross_entropy = tf.reduce_mean(-tf.reduce_sum(Y*tf.log(l1['outdata']), reduction_indices= [1]))
#l1 = add_layer(X, 784, 1024, tf.nn.relu)

l1 = add_layer(X, 784, 1024, None)
prediction = add_layer(l1['outdata'], 1024, 10, tf.nn.softmax)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(Y*tf.log(prediction['outdata']), reduction_indices= [1]))

optimizer = tf.train.GradientDescentOptimizer(0.000001)
train = optimizer.minimize(cross_entropy)


newW = tf.Variable(tf.random_normal([1024,10]))
newOut = tf.matmul(l1['outdata'],newW)
newSoftMax = tf.nn.softmax(newOut)

init = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init)
  #print(sess.run(l1_Weights))
  for i in range(2):
    X_train, y_train = mnist.train.next_batch(1)
    X_train = X_train/255  #需要进行归一化处理
    #print(sess.run(l1['w'],feed_dict={X:X_train}))
    #print(sess.run(prediction['w'],feed_dict={X:X_train, Y:y_train}))
    #print(sess.run(l1['outdata'],feed_dict={X:X_train, Y:y_train}).shape)
    print(sess.run(prediction['outdata'],feed_dict={X:X_train, Y:y_train}))
    print(sess.run(newOut, feed_dict={X:X_train}))
    print(sess.run(newSoftMax, feed_dict={X:X_train}))
    print(y_train)
    #print(sess.run(l1['outdata'], feed_dict={X:X_train}))
    sess.run(train, feed_dict={X:X_train, Y:y_train})
    if i%100 == 0:
      #print(sess.run(cross_entropy, feed_dict={X:X_train, Y:y_train}))
      accuracy = get_accuracy(mnist.test.labels)
      print(sess.run(accuracy,feed_dict={X:mnist.test.images}))
    
    #if i%100==0:
    #print(sess.run(prediction, feed_dict={X:X_train}))
    #print(sess.run(cross_entropy, feed_dict={X:X_train,Y:y_train}))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Swift 3.0在集合类数据结构上的一些新变化总结
Jul 11 Python
Python人脸识别初探
Dec 21 Python
Python3.5 创建文件的简单实例
Apr 26 Python
Django框架登录加上验证码校验实现验证功能示例
May 23 Python
django框架实现模板中获取request 的各种信息示例
Jul 01 Python
Python中的self用法详解
Aug 06 Python
使用Python刷淘宝喵币(低阶入门版)
Oct 30 Python
Python如何通过百度翻译API实现翻译功能
Apr 02 Python
python实现PDF中表格转化为Excel的方法
Jun 16 Python
Python模块zipfile原理及使用方法详解
Aug 04 Python
python 实现数据库中数据添加、查询与更新的示例代码
Dec 07 Python
python神经网络 使用Keras构建RNN训练
May 04 Python
用Eclipse写python程序
Feb 10 #Python
tensorflow建立一个简单的神经网络的方法
Feb 10 #Python
python取代netcat过程分析
Feb 10 #Python
浅谈Python黑帽子取代netcat
Feb 10 #Python
python3爬取淘宝信息代码分析
Feb 10 #Python
Python中property属性实例解析
Feb 10 #Python
Java编程迭代地删除文件夹及其下的所有文件实例
Feb 10 #Python
You might like
PHP发明人谈MVC和网站设计架构 貌似他不支持php用mvc
2011/06/04 PHP
php代码中使用换行及(\n或\r\n和br)的应用
2013/02/02 PHP
朋友网关于QQ相关的PHP代码(研究QQ的绝佳资料)
2015/01/26 PHP
PHP实现的迪科斯彻(Dijkstra)最短路径算法实例
2017/09/16 PHP
最佳JS代码编写的14条技巧
2011/01/09 Javascript
用jquery存取照片的具体实现方法
2013/06/30 Javascript
JavaScript fontsize方法入门实例(按照指定的尺寸来显示字符串)
2014/10/17 Javascript
JS中对象与字符串的互相转换详解
2016/05/20 Javascript
BootStrap modal模态弹窗使用小结
2016/10/26 Javascript
bootstrap基础知识学习笔记
2016/11/02 Javascript
jQuery拖拽通过八个点改变div大小
2020/11/29 Javascript
JavaScript DOM节点操作实例小结(新建,删除HTML元素)
2017/01/19 Javascript
微信小程序 页面跳转及数据传递详解
2017/03/14 Javascript
ES6新特性之变量和字符串用法示例
2017/04/01 Javascript
webpack 2的react开发配置实例代码
2017/07/28 Javascript
vue项目中跳转到外部链接的实例讲解
2018/09/20 Javascript
使用vue脚手架(vue-cli)搭建一个项目详解
2019/05/09 Javascript
js实现烟花特效
2020/03/02 Javascript
ant design pro中可控的筛选和排序实例
2020/11/17 Javascript
go和python调用其它程序并得到程序输出
2014/02/10 Python
python的keyword模块用法实例分析
2015/06/30 Python
在Django框架中运行Python应用全攻略
2015/07/17 Python
Python numpy生成矩阵、串联矩阵代码分享
2017/12/04 Python
python使用RNN实现文本分类
2018/05/24 Python
Python3实现的简单三级菜单功能示例
2019/03/12 Python
详解python中@的用法
2019/03/27 Python
python @classmethod 的使用场合详解
2019/08/23 Python
pandas的resample重采样的使用
2020/04/24 Python
Python单元测试及unittest框架用法实例解析
2020/07/09 Python
京东全球售:直邮香港,澳门,台湾,美国,澳大利亚等地区
2017/09/24 全球购物
立志成才演讲稿
2014/09/04 职场文书
检察机关个人对照检查材料
2014/09/15 职场文书
先进事迹材料范文
2014/12/29 职场文书
综合素质评价自我评价
2015/03/06 职场文书
python数字图像处理之图像自动阈值分割示例
2022/06/28 Python
Pytorch中expand()的使用(扩展某个维度)
2022/07/15 Python