TensorFlow学习之分布式的TensorFlow运行环境


Posted in Python onFebruary 05, 2020

当我们在大型的数据集上面进行深度学习的训练时,往往需要大量的运行资源,而且还要花费大量时间才能完成训练。

1.分布式TensorFlow的角色与原理

在分布式的TensorFlow中的角色分配如下:

PS:作为分布式训练的服务端,等待各个终端(supervisors)来连接。

worker:在TensorFlow的代码注释中被称为终端(supervisors),作为分布式训练的计算资源终端。

chief supervisors:在众多的运算终端中必须选择一个作为主要的运算终端。该终端在运算终端中最先启动,它的功能是合并各个终端运算后的学习参数,将其保存或者载入。

每个具体的网络标识都是唯一的,即分布在不同IP的机器上(或者同一个机器的不同端口)。在实际的运行中,各个角色的网络构建部分代码必须100%的相同。三者的分工如下:

服务端作为一个多方协调者,等待各个运算终端来连接。

chief supervisors会在启动时同一管理全局的学习参数,进行初始化或者从模型载入。

其他的运算终端只是负责得到其对应的任务并进行计算,并不会保存检查点,用于TensorBoard可视化中的summary日志等任何参数信息。

在整个过程都是通过RPC协议来进行通信的。

2.分布部署TensorFlow的具体方法

配置过程中,首先建立一个server,在server中会将ps及所有worker的IP端口准备好。接着,使用tf.train.Supervisor中的managed_ssion来管理一个打开的session。session中只是负责运算,而通信协调的事情就都交给supervisor来管理了。

3.部署训练实例

下面开始实现一个分布式训练的网络模型,以线性回归为例,通过3个端口来建立3个终端,分别是一个ps,两个worker,实现TensorFlow的分布式运算。

1. 为每个角色添加IP地址和端口,创建sever,在一台机器上开3个不同的端口,分别代表PS,chief supervisor和worker。角色的名称用strjob_name表示,以ps为例,代码如下:

# 定义IP和端口
strps_hosts = 'localhost:1681'
strworker_hosts = 'localhost:1682,localhost:1683'
# 定义角色名称
strjob_name = 'ps'
task_index = 0
# 将字符串转数组
ps_hosts = strps_hosts.split(',')
worker_hosts = strps_hosts.split(',')
cluster_spec = tf.train.ClusterSpec({'ps': ps_hosts, 'worker': worker_hosts})
# 创建server
server = tf.train.Server({'ps':ps_hosts, 'worker':worker_hosts}, job_name=strjob_name, task_index=task_index)

2为ps角色添加等待函数

ps角色使用server.join函数进行线程挂起,开始接受连续消息。

# ps角色使用join进行等待
if strjob_name == 'ps':
  print("wait")
  server.join()

3.创建网络的结构

与正常的程序不同,在创建网络结构时,使用tf.device函数将全部的节点都放在当前任务下。在tf.device函数中的任务是通过tf.train.replica_device_setter来指定的。在tf.train.replica_device_setter中使用worker_device来定义具体任务名称;使用cluster的配置来指定角色及对应的IP地址,从而实现管理整个任务下的图节点。代码如下:

with tf.device(tf.train.replica_device_setter(worker_device='/job:worker/task:%d'%task_index,
                       cluster=cluster_spec)):
  X = tf.placeholder('float')
  Y = tf.placeholder('float')
  # 模型参数
  w = tf.Variable(tf.random_normal([1]), name='weight')
  b = tf.Variable(tf.zeros([1]), name='bias')
  global_step = tf.train.get_or_create_global_step()  # 获取迭代次数
  z = tf.multiply(X, w) + b
  tf.summary('z', z)
  cost = tf.reduce_mean(tf.square(Y - z))
  tf.summary.scalar('loss_function', cost)
  learning_rate = 0.001
  optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost, global_step=global_step)
  saver = tf.train.Saver(max_to_keep=1)
  merged_summary_op = tf.summary.merge_all() # 合并所有summary
  init = tf.global_variables_initializer()

4.创建Supercisor,管理session

在tf.train.Supervisor函数中,is_chief表明为是否为chief Supervisor角色,这里将task_index=0的worker设置成chief Supervisor。saver需要将保存检查点的saver对象传入。init_op表示使用初始化变量的函数。

training_epochs = 2000
display_step = 2
sv = tf.train.Supervisor(is_chief=(task_index == 0),# 0号为chief
             logdir='log/spuer/',
             init_op=init,
             summary_op=None,
             saver=saver,
             global_step=global_step,
             save_model_secs=5)
# 连接目标角色创建session
with sv.managed_session(saver.target) as sess:

5迭代训练

session中的内容与以前一样,直接迭代训练即可。由于使用了supervisor管理session,将使用sv.summary_computed函数来保存summary文件。

print('sess ok')
  print(global_step.eval(session=sess))
  for epoch in range(global_step.eval(session=sess), training_epochs*len(train_x)):
    for (x, y) in zip(train_x, train_y):
      _, epoch = sess.run([optimizer, global_step], feed_dict={X: x, Y: y})
      summary_str = sess.run(merged_summary_op, feed_dict={X: x, Y: y})
      sv.summary_computed(sess, summary_str, global_step=epoch)
      if epoch % display_step == 0:
        loss = sess.run(cost, feed_dict={X:train_x, Y:train_y})
        print("Epoch:", epoch+1, 'loss:', loss, 'W=', sess.run(w), w, 'b=', sess.run(b))
  print(' finished ')
  sv.saver.save(sess, 'log/linear/' + "sv.cpk", global_step=epoch)
sv.stop()

(1)在设置自动保存检查点文件后,手动保存仍然有效,

(2)在运行一半后,在运行supervisor时会自动载入模型的参数,不需要手动调用restore。

(3)在session中不需要进行初始化的操作。

6.建立worker文件

新建两个py文件,设置task_index分别为0和1,其他的部分和上述的代码相一致。

strjob_name = 'worker'
task_index = 1
strjob_name = 'worker'
task_index = 0

7.运行

我们分别启动写好的三个文件,在运行结果中,我们可以看到循环的次数不是连续的,显示结果中会有警告,这是因为在构建supervisor时没有填写local_init_op参数,该参数的含义是在创建worker实例时,初始化本地变量,上述代码中没有设置,系统会自动初始化,并给出警告提示。

分布运算的目的是为了提高整体运算速度,如果同步epoch的准确率需要牺牲总体运行速度为代价,自然很不合适。

在ps的文件中,它只是负责连接,并不参与运算。

总结

以上所述是小编给大家介绍的TensorFlow学习之分布式的TensorFlow运行环境,希望对大家有所帮助!!

Python 相关文章推荐
用Python生成器实现微线程编程的教程
Apr 13 Python
在Django的URLconf中进行函数导入的方法
Jul 18 Python
python模拟登录并且保持cookie的方法详解
Apr 04 Python
Python去除、替换字符串空格的处理方法
Apr 01 Python
Python装饰器原理与用法分析
Apr 30 Python
详解python如何在django中为用户模型添加自定义权限
Oct 15 Python
python实现n个数中选出m个数的方法
Nov 13 Python
python获取本机所有IP地址的方法
Dec 26 Python
python通过tcp发送xml报文的方法
Dec 28 Python
PyQt5 窗口切换与自定义对话框的实例
Jun 20 Python
scrapy爬虫:scrapy.FormRequest中formdata参数详解
Apr 30 Python
python获得命令行输入的参数的两种方式
Nov 02 Python
TensorFlow MNIST手写数据集的实现方法
Feb 05 #Python
tensorflow之并行读入数据详解
Feb 05 #Python
tensorflow mnist 数据加载实现并画图效果
Feb 05 #Python
tensorflow 自定义损失函数示例代码
Feb 05 #Python
利用Tensorflow的队列多线程读取数据方式
Feb 05 #Python
Tensorflow 多线程与多进程数据加载实例
Feb 05 #Python
TensorFlow自定义损失函数来预测商品销售量
Feb 05 #Python
You might like
php自定义加密与解密程序实例
2014/12/31 PHP
PHP中4种常用的抓取网络数据方法
2015/06/04 PHP
php字符串比较函数用法小结(strcmp,strcasecmp,strnatcmp及strnatcasecmp)
2016/07/18 PHP
yii2.0数据库迁移教程【多个数据库同时同步数据】
2016/10/08 PHP
jquery中子元素和后代元素的区别示例介绍
2014/04/02 Javascript
JavaScript实现找出数组中最长的连续数字序列
2014/09/03 Javascript
JavaScript和JQuery的鼠标mouse事件冒泡处理
2015/06/19 Javascript
今天抽时间给大家整理jquery和ajax的相关知识
2015/11/17 Javascript
给before和after伪元素设置js效果的方法
2015/12/04 Javascript
jquery实现左右无缝轮播图
2020/07/31 Javascript
详解Jquery EasyUI tree 的异步加载(遍历指定文件夹,根据文件夹内的文件生成tree)
2017/02/11 Javascript
基于pako.js实现gzip的压缩和解压功能示例
2017/06/13 Javascript
解决Vue2.0自带浏览器里无法打开的原因(兼容处理)
2017/07/28 Javascript
基于JavaScript中标识符的命名规则介绍
2018/01/06 Javascript
微信小程序仿美团城市选择
2018/06/06 Javascript
vue+webpack中配置ESLint
2018/11/07 Javascript
ES6顶层对象、global对象实例分析
2019/06/14 Javascript
微信小程序如何自定义table组件
2019/06/29 Javascript
layui radio单选限制下一个radio单选的实例
2019/09/03 Javascript
vue实现浏览器全屏展示功能
2019/11/27 Javascript
使用vue实现通过变量动态拼接url
2020/07/22 Javascript
[01:03:33]Alliance vs TNC 2019国际邀请赛小组赛 BO2 第一场 8.16
2019/08/18 DOTA
Python编写的com组件发生R6034错误的原因与解决办法
2013/04/01 Python
python基础教程之基本内置数据类型介绍
2014/02/20 Python
python脚本作为Windows服务启动代码详解
2018/02/11 Python
Python定义函数实现累计求和操作
2020/05/03 Python
基于python实现地址和经纬度转换
2020/05/19 Python
HTML5操作WebSQL数据库的实例代码
2017/08/26 HTML / CSS
丽笙酒店官方网站:Radisson Hotels
2019/05/07 全球购物
加州风格的游泳和沙滩装品牌:Cupshe
2019/06/10 全球购物
Hotels.com拉丁美洲:从豪华酒店到经济型酒店的预定优惠和折扣
2019/12/09 全球购物
Feelunique中文官网:欧洲最大化妆品零售电商
2020/07/10 全球购物
会议主持人开场白台词
2015/05/28 职场文书
英语投诉信范文
2015/07/03 职场文书
干货:如何写好工作计划!
2019/05/17 职场文书
Python使用pandas导入xlsx格式的excel文件内容操作代码
2022/12/24 Python