TensorFlow实现Batch Normalization


Posted in Python onMarch 08, 2018

一、BN(Batch Normalization)算法

1. 对数据进行归一化处理的重要性

神经网络学习过程的本质就是学习数据分布,在训练数据与测试数据分布不同情况下,模型的泛化能力就大大降低;另一方面,若训练过程中每批batch的数据分布也各不相同,那么网络每批迭代学习过程也会出现较大波动,使之更难趋于收敛,降低训练收敛速度。对于深层网络,网络前几层的微小变化都会被网络累积放大,则训练数据的分布变化问题会被放大,更加影响训练速度。

2. BN算法的强大之处

1)为了加速梯度下降算法的训练,我们可以采取指数衰减学习率等方法在初期快速学习,后期缓慢进入全局最优区域。使用BN算法后,就可以直接选择比较大的学习率,且设置很大的学习率衰减速度,大大提高训练速度。即使选择了较小的学习率,也会比以前不使用BN情况下的收敛速度快。总结就是BN算法具有快速收敛的特性。

2)BN具有提高网络泛化能力的特性。采用BN算法后,就可以移除针对过拟合问题而设置的dropout和L2正则化项,或者采用更小的L2正则化参数。

3)BN本身是一个归一化网络层,则局部响应归一化层(Local Response Normalization,LRN层)则可不需要了(Alexnet网络中使用到)。

3. BN算法概述

BN算法提出了变换重构,引入了可学习参数γ、β,这就是算法的关键之处:

TensorFlow实现Batch Normalization

引入这两个参数后,我们的网络便可以学习恢复出原是网络所要学习的特征分布,BN层的钱箱传到过程如下:

TensorFlow实现Batch Normalization

其中m为batchsize。BatchNormalization中所有的操作都是平滑可导,这使得back propagation可以有效运行并学到相应的参数γ,β。需要注意的一点是Batch Normalization在training和testing时行为有所差别。Training时μβ和σβ由当前batch计算得出;在Testing时μβ和σβ应使用Training时保存的均值或类似的经过处理的值,而不是由当前batch计算。

二、TensorFlow相关函数

1.tf.nn.moments(x, axes, shift=None, name=None, keep_dims=False)

x是输入张量,axes是在哪个维度上求解, 即想要 normalize的维度, [0] 代表 batch 维度,如果是图像数据,可以传入 [0, 1, 2],相当于求[batch, height, width] 的均值/方差,注意不要加入channel 维度。该函数返回两个张量,均值mean和方差variance。

2.tf.identity(input, name=None)

返回与输入张量input形状和内容一致的张量。

3.tf.nn.batch_normalization(x, mean, variance, offset, scale, variance_epsilon,name=None)

计算公式为scale(x - mean)/ variance + offset。

这些参数中,tf.nn.moments可得到均值mean和方差variance,offset和scale是可训练的,offset一般初始化为0,scale初始化为1,offset和scale的shape与mean相同,variance_epsilon参数设为一个很小的值如0.001。

三、TensorFlow代码实现

1. 完整代码

import tensorflow as tf 
import numpy as np 
import matplotlib.pyplot as plt 
 
ACTIVITION = tf.nn.relu 
N_LAYERS = 7 # 总共7层隐藏层 
N_HIDDEN_UNITS = 30 # 每层包含30个神经元 
 
def fix_seed(seed=1): # 设置随机数种子 
  np.random.seed(seed) 
  tf.set_random_seed(seed) 
 
def plot_his(inputs, inputs_norm): # 绘制直方图函数 
  for j, all_inputs in enumerate([inputs, inputs_norm]): 
    for i, input in enumerate(all_inputs): 
      plt.subplot(2, len(all_inputs), j*len(all_inputs)+(i+1)) 
      plt.cla() 
      if i == 0: 
        the_range = (-7, 10) 
      else: 
        the_range = (-1, 1) 
      plt.hist(input.ravel(), bins=15, range=the_range, color='#FF5733') 
      plt.yticks(()) 
      if j == 1: 
        plt.xticks(the_range) 
      else: 
        plt.xticks(()) 
      ax = plt.gca() 
      ax.spines['right'].set_color('none') 
      ax.spines['top'].set_color('none') 
    plt.title("%s normalizing" % ("Without" if j == 0 else "With")) 
  plt.draw() 
  plt.pause(0.01) 
 
def built_net(xs, ys, norm): # 搭建网络函数 
  # 添加层 
  def add_layer(inputs, in_size, out_size, activation_function=None, norm=False): 
    Weights = tf.Variable(tf.random_normal([in_size, out_size], 
                        mean=0.0, stddev=1.0)) 
    biases = tf.Variable(tf.zeros([1, out_size]) + 0.1) 
    Wx_plus_b = tf.matmul(inputs, Weights) + biases 
 
    if norm: # 判断是否是Batch Normalization层 
      # 计算均值和方差,axes参数0表示batch维度 
      fc_mean, fc_var = tf.nn.moments(Wx_plus_b, axes=[0]) 
      scale = tf.Variable(tf.ones([out_size])) 
      shift = tf.Variable(tf.zeros([out_size])) 
      epsilon = 0.001 
 
      # 定义滑动平均模型对象 
      ema = tf.train.ExponentialMovingAverage(decay=0.5) 
 
      def mean_var_with_update(): 
        ema_apply_op = ema.apply([fc_mean, fc_var]) 
        with tf.control_dependencies([ema_apply_op]): 
          return tf.identity(fc_mean), tf.identity(fc_var) 
 
      mean, var = mean_var_with_update() 
 
      Wx_plus_b = tf.nn.batch_normalization(Wx_plus_b, mean, var, 
                         shift, scale, epsilon) 
 
    if activation_function is None: 
      outputs = Wx_plus_b 
    else: 
      outputs = activation_function(Wx_plus_b) 
    return outputs 
 
  fix_seed(1) 
 
  if norm: # 为第一层进行BN 
    fc_mean, fc_var = tf.nn.moments(xs, axes=[0]) 
    scale = tf.Variable(tf.ones([1])) 
    shift = tf.Variable(tf.zeros([1])) 
    epsilon = 0.001 
 
    ema = tf.train.ExponentialMovingAverage(decay=0.5) 
 
    def mean_var_with_update(): 
      ema_apply_op = ema.apply([fc_mean, fc_var]) 
      with tf.control_dependencies([ema_apply_op]): 
        return tf.identity(fc_mean), tf.identity(fc_var) 
 
    mean, var = mean_var_with_update() 
    xs = tf.nn.batch_normalization(xs, mean, var, shift, scale, epsilon) 
 
  layers_inputs = [xs] # 记录每一层的输入 
 
  for l_n in range(N_LAYERS): # 依次添加7层 
    layer_input = layers_inputs[l_n] 
    in_size = layers_inputs[l_n].get_shape()[1].value 
 
    output = add_layer(layer_input, in_size, N_HIDDEN_UNITS, ACTIVITION, norm) 
    layers_inputs.append(output) 
 
  prediction = add_layer(layers_inputs[-1], 30, 1, activation_function=None) 
  cost = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), 
                    reduction_indices=[1])) 
 
  train_op = tf.train.GradientDescentOptimizer(0.001).minimize(cost) 
  return [train_op, cost, layers_inputs] 
 
fix_seed(1) 
x_data = np.linspace(-7, 10, 2500)[:, np.newaxis] 
np.random.shuffle(x_data) 
noise =np.random.normal(0, 8, x_data.shape) 
y_data = np.square(x_data) - 5 + noise 
 
plt.scatter(x_data, y_data) 
plt.show() 
 
xs = tf.placeholder(tf.float32, [None, 1]) 
ys = tf.placeholder(tf.float32, [None, 1]) 
 
train_op, cost, layers_inputs = built_net(xs, ys, norm=False) 
train_op_norm, cost_norm, layers_inputs_norm = built_net(xs, ys, norm=True) 
 
with tf.Session() as sess: 
  sess.run(tf.global_variables_initializer()) 
 
  cost_his = [] 
  cost_his_norm = [] 
  record_step = 5 
 
  plt.ion() 
  plt.figure(figsize=(7, 3)) 
  for i in range(250): 
    if i % 50 == 0: 
      all_inputs, all_inputs_norm = sess.run([layers_inputs, layers_inputs_norm], 
                          feed_dict={xs: x_data, ys: y_data}) 
      plot_his(all_inputs, all_inputs_norm) 
 
    sess.run([train_op, train_op_norm], 
         feed_dict={xs: x_data[i*10:i*10+10], ys: y_data[i*10:i*10+10]}) 
 
    if i % record_step == 0: 
      cost_his.append(sess.run(cost, feed_dict={xs: x_data, ys: y_data})) 
      cost_his_norm.append(sess.run(cost_norm, 
                     feed_dict={xs: x_data, ys: y_data})) 
 
  plt.ioff() 
  plt.figure() 
  plt.plot(np.arange(len(cost_his))*record_step, 
       np.array(cost_his), label='Without BN')   # no norm 
  plt.plot(np.arange(len(cost_his))*record_step, 
       np.array(cost_his_norm), label='With BN')  # norm 
  plt.legend() 
  plt.show()

2. 实验结果

输入数据分布:

TensorFlow实现Batch Normalization

批标准化BN效果对比:

TensorFlow实现Batch Normalization

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python切换pip安装源的方法详解
Nov 18 Python
hmac模块生成加入了密钥的消息摘要详解
Jan 11 Python
使用Python通过win32 COM实现Word文档的写入与保存方法
May 08 Python
Python3处理HTTP请求的实例
May 10 Python
matplotlib.pyplot画图 图片的二进制流的获取方法
May 24 Python
解决python3 pika之连接断开的问题
Dec 18 Python
深入浅析Python 中 is 语法带来的误解
May 07 Python
Python爬取破解无线网络wifi密码过程解析
Sep 17 Python
简单了解python中的f.b.u.r函数
Nov 02 Python
django框架中ajax的使用及避开CSRF 验证的方式详解
Dec 11 Python
Python通过递归函数输出嵌套列表元素
Oct 15 Python
python如何修改文件时间属性
Feb 05 Python
用Django实现一个可运行的区块链应用
Mar 08 #Python
Python pyinotify日志监控系统处理日志的方法
Mar 08 #Python
TensorFlow模型保存和提取的方法
Mar 08 #Python
火车票抢票python代码公开揭秘!
Mar 08 #Python
Python实现定时备份mysql数据库并把备份数据库邮件发送
Mar 08 #Python
python实现12306抢票及自动邮件发送提醒付款功能
Mar 08 #Python
TensorFlow模型保存/载入的两种方法
Mar 08 #Python
You might like
PHP 高级课程笔记 面向对象
2009/06/21 PHP
phplock(php进程锁) v1.0 beta1
2009/11/24 PHP
ThinkPHP使用PHPExcel实现Excel数据导入导出完整实例
2014/07/22 PHP
php中数字、字符与对象判断函数用法实例
2014/11/26 PHP
php each 返回数组中当前的键值对并将数组指针向前移动一步实例
2016/11/22 PHP
Thinkphp5.0框架的Db操作实例分析【连接、增删改查、链式操作等】
2019/10/11 PHP
ThinkPHP 框架实现的读取excel导入数据库操作示例
2020/04/14 PHP
php7 新增功能实例总结
2020/05/25 PHP
页面中js执行顺序
2009/11/09 Javascript
jquery.validate使用攻略 第五步 正则验证
2010/07/01 Javascript
jquery实现邮箱自动补全功能示例分享
2014/02/17 Javascript
setTimeout()递归调用不加引号出错的解决方法
2014/09/05 Javascript
jQuery在ul中显示某个li索引号的方法
2015/03/17 Javascript
JavaScript实现自动消除按钮功能的方法
2015/08/05 Javascript
AngularJS 表达式详解及实例代码
2016/09/14 Javascript
详解vue-cli与webpack结合如何处理静态资源
2017/09/19 Javascript
微信小程序实现MUI数字输入框效果
2018/01/31 Javascript
vue二级路由设置方法
2018/02/09 Javascript
JavaScript实现简单轮播图效果
2018/12/01 Javascript
微信小程序实现消息框弹出动画
2020/04/18 Javascript
小程序实现悬浮搜索框
2019/07/12 Javascript
JS中数组实现代码(倒序遍历数组,数组连接字符串)
2019/12/29 Javascript
Python异常处理总结
2014/08/15 Python
python分割文件的常用方法
2014/11/01 Python
scrapy自定义pipeline类实现将采集数据保存到mongodb的方法
2015/04/16 Python
详解Python中的正斜杠与反斜杠
2019/08/09 Python
解决Python spyder显示不全df列和行的问题
2020/04/20 Python
python中pdb模块实例用法
2021/01/15 Python
LightInTheBox西班牙站点:全球商品在线采购
2016/09/22 全球购物
美国用餐电影院:Alamo Drafthouse Cinema
2020/01/23 全球购物
SmartBuyGlasses荷兰:购买太阳镜和眼镜
2020/03/16 全球购物
报关简历自我评价怎么写
2013/09/19 职场文书
预备党员承诺书
2014/03/25 职场文书
暖通工程师岗位职责
2014/06/12 职场文书
人民检察院起诉书
2015/05/20 职场文书
Python使用Opencv打开笔记本电脑摄像头报错解问题及解决
2022/06/21 Python