TensorFlow实现随机训练和批量训练的方法


Posted in Python onApril 28, 2018

TensorFlow更新模型变量。它能一次操作一个数据点,也可以一次操作大量数据。一个训练例子上的操作可能导致比较“古怪”的学习过程,但使用大批量的训练会造成计算成本昂贵。到底选用哪种训练类型对机器学习算法的收敛非常关键。

为了TensorFlow计算变量梯度来让反向传播工作,我们必须度量一个或者多个样本的损失。

随机训练会一次随机抽样训练数据和目标数据对完成训练。另外一个可选项是,一次大批量训练取平均损失来进行梯度计算,批量训练大小可以一次上扩到整个数据集。这里将显示如何扩展前面的回归算法的例子——使用随机训练和批量训练。

批量训练和随机训练的不同之处在于它们的优化器方法和收敛。

# 随机训练和批量训练
#----------------------------------
#
# This python function illustrates two different training methods:
# batch and stochastic training. For each model, we will use
# a regression model that predicts one model variable.

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
ops.reset_default_graph()

# 随机训练:
# Create graph
sess = tf.Session()

# 声明数据
x_vals = np.random.normal(1, 0.1, 100)
y_vals = np.repeat(10., 100)
x_data = tf.placeholder(shape=[1], dtype=tf.float32)
y_target = tf.placeholder(shape=[1], dtype=tf.float32)

# 声明变量 (one model parameter = A)
A = tf.Variable(tf.random_normal(shape=[1]))

# 增加操作到图
my_output = tf.multiply(x_data, A)

# 增加L2损失函数
loss = tf.square(my_output - y_target)

# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 声明优化器
my_opt = tf.train.GradientDescentOptimizer(0.02)
train_step = my_opt.minimize(loss)

loss_stochastic = []
# 运行迭代
for i in range(100):
 rand_index = np.random.choice(100)
 rand_x = [x_vals[rand_index]]
 rand_y = [y_vals[rand_index]]
 sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
 if (i+1)%5==0:
  print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  print('Loss = ' + str(temp_loss))
  loss_stochastic.append(temp_loss)


# 批量训练:
# 重置计算图
ops.reset_default_graph()
sess = tf.Session()

# 声明批量大小
# 批量大小是指通过计算图一次传入多少训练数据
batch_size = 20

# 声明模型的数据、占位符
x_vals = np.random.normal(1, 0.1, 100)
y_vals = np.repeat(10., 100)
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# 声明变量 (one model parameter = A)
A = tf.Variable(tf.random_normal(shape=[1,1]))

# 增加矩阵乘法操作(矩阵乘法不满足交换律)
my_output = tf.matmul(x_data, A)

# 增加损失函数
# 批量训练时损失函数是每个数据点L2损失的平均值
loss = tf.reduce_mean(tf.square(my_output - y_target))

# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 声明优化器
my_opt = tf.train.GradientDescentOptimizer(0.02)
train_step = my_opt.minimize(loss)

loss_batch = []
# 运行迭代
for i in range(100):
 rand_index = np.random.choice(100, size=batch_size)
 rand_x = np.transpose([x_vals[rand_index]])
 rand_y = np.transpose([y_vals[rand_index]])
 sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
 if (i+1)%5==0:
  print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  print('Loss = ' + str(temp_loss))
  loss_batch.append(temp_loss)

plt.plot(range(0, 100, 5), loss_stochastic, 'b-', label='Stochastic Loss')
plt.plot(range(0, 100, 5), loss_batch, 'r--', label='Batch Loss, size=20')
plt.legend(loc='upper right', prop={'size': 11})
plt.show()

输出:

Step #5 A = [ 1.47604525]
Loss = [ 72.55678558]
Step #10 A = [ 3.01128507]
Loss = [ 48.22986221]
Step #15 A = [ 4.27042341]
Loss = [ 28.97912598]
Step #20 A = [ 5.2984333]
Loss = [ 16.44779968]
Step #25 A = [ 6.17473984]
Loss = [ 16.373312]
Step #30 A = [ 6.89866304]
Loss = [ 11.71054649]
Step #35 A = [ 7.39849901]
Loss = [ 6.42773056]
Step #40 A = [ 7.84618378]
Loss = [ 5.92940331]
Step #45 A = [ 8.15709782]
Loss = [ 0.2142024]
Step #50 A = [ 8.54818344]
Loss = [ 7.11651039]
Step #55 A = [ 8.82354641]
Loss = [ 1.47823763]
Step #60 A = [ 9.07896614]
Loss = [ 3.08244276]
Step #65 A = [ 9.24868107]
Loss = [ 0.01143846]
Step #70 A = [ 9.36772251]
Loss = [ 2.10078788]
Step #75 A = [ 9.49171734]
Loss = [ 3.90913701]
Step #80 A = [ 9.6622715]
Loss = [ 4.80727625]
Step #85 A = [ 9.73786926]
Loss = [ 0.39915398]
Step #90 A = [ 9.81853104]
Loss = [ 0.14876099]
Step #95 A = [ 9.90371323]
Loss = [ 0.01657014]
Step #100 A = [ 9.86669159]
Loss = [ 0.444787]
Step #5 A = [[ 2.34371352]]
Loss = 58.766
Step #10 A = [[ 3.74766445]]
Loss = 38.4875
Step #15 A = [[ 4.88928795]]
Loss = 27.5632
Step #20 A = [[ 5.82038736]]
Loss = 17.9523
Step #25 A = [[ 6.58999157]]
Loss = 13.3245
Step #30 A = [[ 7.20851326]]
Loss = 8.68099
Step #35 A = [[ 7.71694899]]
Loss = 4.60659
Step #40 A = [[ 8.1296711]]
Loss = 4.70107
Step #45 A = [[ 8.47107315]]
Loss = 3.28318
Step #50 A = [[ 8.74283409]]
Loss = 1.99057
Step #55 A = [[ 8.98811722]]
Loss = 2.66906
Step #60 A = [[ 9.18062305]]
Loss = 3.26207
Step #65 A = [[ 9.31655025]]
Loss = 2.55459
Step #70 A = [[ 9.43130589]]
Loss = 1.95839
Step #75 A = [[ 9.55670166]]
Loss = 1.46504
Step #80 A = [[ 9.6354847]]
Loss = 1.49021
Step #85 A = [[ 9.73470974]]
Loss = 1.53289
Step #90 A = [[ 9.77956581]]
Loss = 1.52173
Step #95 A = [[ 9.83666706]]
Loss = 0.819207
Step #100 A = [[ 9.85569191]]
Loss = 1.2197

TensorFlow实现随机训练和批量训练的方法

训练类型 优点 缺点
随机训练 脱离局部最小 一般需更多次迭代才收敛
批量训练 快速得到最小损失 耗费更多计算资源

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python2.x和3.x下maketrans与translate函数使用上的不同
Apr 13 Python
利用Python实现网络测试的脚本分享
May 26 Python
基于python爬虫数据处理(详解)
Jun 10 Python
可能是最全面的 Python 字符串拼接总结【收藏】
Jul 09 Python
python调用matplotlib模块绘制柱状图
Oct 18 Python
Python通过Manager方式实现多个无关联进程共享数据的实现
Nov 07 Python
python实现异常信息堆栈输出到日志文件
Dec 26 Python
Java如何基于wsimport调用wcf接口
Jun 17 Python
Python+Dlib+Opencv实现人脸采集并表情判别功能的代码
Jul 01 Python
Python tensorflow卷积神经Inception V3网络结构
May 06 Python
Python可视化神器pyecharts之绘制地理图表练习
Jul 07 Python
Python TypeError: ‘float‘ object is not subscriptable错误解决
Dec 24 Python
对python中的logger模块全面讲解
Apr 28 #Python
详解PyTorch批训练及优化器比较
Apr 28 #Python
Python使用matplotlib实现的图像读取、切割裁剪功能示例
Apr 28 #Python
浅谈python日志的配置文件路径问题
Apr 28 #Python
PyTorch上实现卷积神经网络CNN的方法
Apr 28 #Python
python 日志增量抓取实现方法
Apr 28 #Python
Django 使用logging打印日志的实例
Apr 28 #Python
You might like
PHP 5.3和PHP 5.4出现FastCGI Error解决方法
2015/02/12 PHP
php实现随机生成易于记忆的密码
2015/06/19 PHP
thinkPHP中多维数组的遍历方法
2016/01/09 PHP
php实现的网页版剪刀石头布游戏示例
2016/11/25 PHP
用js统计用户下载网页所需时间的脚本
2008/10/15 Javascript
JS input文本框禁用右键和复制粘贴功能的代码
2010/04/15 Javascript
基于jquery的代码显示区域自动拉长效果
2011/12/07 Javascript
jquery组件使用中遇到的问题整理及解决
2014/02/21 Javascript
jquery实现刷新随机变化样式特效(tag标签样式)
2017/02/03 Javascript
AngularJS学习第一篇 AngularJS基础知识
2017/02/13 Javascript
Angularjs自定义指令实现三级联动 选择地理位置
2017/02/13 Javascript
JavaScript实现开关等效果
2017/09/08 Javascript
Nodejs实现文件上传的示例代码
2017/09/26 NodeJs
JavaScript判断变量名是否存在数组中的实例
2017/12/28 Javascript
nodejs用gulp管理前端文件方法
2018/06/24 NodeJs
js实现贪吃蛇小游戏(加墙)
2020/07/31 Javascript
[02:03]《现实生活中的DOTA2》—林书豪&DOTA2职业选手出演短片
2015/08/18 DOTA
[52:22]EG vs VG Supermajor小组赛B组 BO3 第一场 6.2
2018/06/03 DOTA
[46:40]VGJ.T vs Winstrike 2018国际邀请赛小组赛BO2 第一场 8.17
2018/08/20 DOTA
Python中基础的socket编程实战攻略
2016/06/01 Python
python http接口自动化脚本详解
2018/01/02 Python
对python中dict和json的区别详解
2018/12/18 Python
Python Matplotlib实现三维数据的散点图绘制
2019/03/19 Python
python使用mitmproxy抓取浏览器请求的方法
2019/07/02 Python
Python Django框架防御CSRF攻击的方法分析
2019/10/18 Python
在pycharm中创建django项目的示例代码
2020/05/28 Python
Python使用socket模块实现简单tcp通信
2020/08/18 Python
五款漂亮的纯CSS3动画按钮的实例教程
2014/11/21 HTML / CSS
巴西图书和电子产品购物网站:Saraiva
2017/06/07 全球购物
小学教师节活动方案
2014/01/31 职场文书
古汉语文学求职信范文
2014/03/16 职场文书
共产党员公开承诺践诺书
2014/05/28 职场文书
领导班子“四风问题”“整改方案
2014/10/02 职场文书
大学文艺委员竞选稿
2015/11/19 职场文书
利用javaScript处理常用事件详解
2021/04/14 Javascript
Django利用AJAX技术实现博文实时搜索
2021/05/06 Python