学习笔记TF061:分布式TensorFlow,分布式原理、最佳实践

清醒疯子 发布于 1周前
无人欣赏。

分布式TensorFlow由高性能gRPC库底层技术支持。Martin Abadi、Ashish Agarwal、Paul Barham论文《TensorFlow:Large-Scale Machine Learning on Heterogeneous Distributed Systems》。

分布式原理。分布式集群 由多个服务器进程、客户端进程组成。部署方式,单机多卡、分布式(多机多卡)。多机多卡TensorFlow分布式。

单机多卡,单台服务器多块GPU。训练过程:在单机单GPU训练,数据一个批次(batch)一个批次训练。单机多GPU,一次处理多个批次数据,每个GPU处理一个批次数据计算。变量参数保存在CPU,数据由CPU分发给多个GPU,GPU计算每个批次更新梯度。CPU收集完多个GPU更新梯度,计算平均梯度,更新参数。继续计算更新梯度。处理速度取决最慢GPU速度。

分布式,训练在多个工作节点(worker)。工作节点,实现计算单元。计算服务器单卡,指服务器。计算服务器多卡,多个GPU划分多个工作节点。数据量大,超过一台机器处理能力,须用分布式。

分布式TensorFlow底层通信,gRPC(google remote procedure call)。gRPC,谷歌开源高性能、跨语言RPC框架。RPC协议,远程过程调用协议,网络从远程计算机程度请求服务。

分布式部署方式。分布式运行,多个计算单元(工作节点),后端服务器部署单工作节点、多工作节点。

单工作节点部署。每台服务器运行一个工作节点,服务器多个GPU,一个工作节点可以访问多块GPU卡。代码tf.device()指定运行操作设备。优势,单机多GPU间通信,效率高。劣势,手动代码指定设备。

多工作节点部署。一台服务器运行多个工作节点。

设置CUDAVISIBLEDEVICES环境变量,限制各个工作节点只可见一个GPU,启动进程添加环境变量。用tf.device()指定特定GPU。多工作节点部署优势,代码简单,提高GPU使用率。劣势,工作节点通信,需部署多个工作节点。https://github.com/tobegit3hub/tensorflowexamples/tree/master/distributedtensorflow 。

CUDA_VISIBLE_DEVICES='' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=ps --task_index=0
CUDA_VISIBLE_DEVICES='' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=ps --task_index=1
CUDA_VISIBLE_DEVICES='0' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=worker --task_index=0
CUDA_VISIBLE_DEVICES='1' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=worker --task_index=1

分布式架构。https://www.tensorflow.org/extend/architecture 。客户端(client)、服务端(server),服务端包括主节点(master)、工作节点(worker)组成。

客户端、主节点、工作节点关系。TensorFlow,客户端会话联系主节点,实际工作由工作节点实现,每个工作节点占一台设备(TensorFlow具体计算硬件抽象,CPU或GPU)。单机模式,客户端、主节点、工作节点在同一台服务器。分布模式,可不同服务器。客户端->主节点->工作节点/job:worker/task:0->/job:ps/task:0。 客户端。建立TensorFlow计算图,建立与集群交互会话层。代码包含Session()。一个客户端可同时与多个服务端相连,一具服务端也可与多个客户端相连。 服务端。运行tf.train.Server实例进程,TensroFlow执行任务集群(cluster)一部分。有主节点服务(Master service)和工作节点服务(Worker service)。运行中,一个主节点进程和数个工作节点进程,主节点进程和工作接点进程通过接口通信。单机多卡和分布式结构相同,只需要更改通信接口实现切换。 主节点服务。实现tensorflow::Session接口。通过RPC服务程序连接工作节点,与工作节点服务进程工作任务通信。TensorFlow服务端,taskindex为0作业(job)。 工作节点服务。实现workerservice.proto接口,本地设备计算部分图。TensorFlow服务端,所有工作节点包含工作节点服务逻辑。每个工作节点负责管理一个或多个设备。工作节点可以是本地不同端口不同进程,或多台服务多个进程。运行TensorFlow分布式执行任务集,一个或多个作业(job)。每个作业,一个或多个相同目的任务(task)。每个任务,一个工作进程执行。作业是任务集合,集群是作业集合。 分布式机器学习框架,作业分参数作业(parameter job)和工作节点作业(worker job)。参数作业运行服务器为参数服务器(parameter server,PS),管理参数存储、更新。工作节点作业,管理无状态主要从事计算任务。模型越大,参数越多,模型参数更新超过一台机器性能,需要把参数分开到不同机器存储更新。参数服务,多台机器组成集群,类似分布式存储架构,涉及数据同步、一致性,参数存储为键值对(key-value)。分布式键值内存数据库,加参数更新操作。李沐《Parameter Server for Distributed Machine Learning》http://www.cs.cmu.edu/~muli/file/ps.pdf 。 参数存储更新在参数作业进行,模型计算在工作节点作业进行。TensorFlow分布式实现作业间数据传输,参数作业到工作节点作业前向传播,工作节点作业到参数作业反向传播。 任务。特定TensorFlow服务器独立进程,在作业中拥有对应序号。一个任务对应一个工作节点。集群->作业->任务->工作节点。

客户端、主节点、工作节点交互过程。单机多卡交互,客户端->会话运行->主节点->执行子图->工作节点->GPU0、GPU1。分布式交互,客户端->会话运行->主节点进程->执行子图1->工作节点进程1->GPU0、GPU1。《TensorFlow:Large-Scale Machine Learning on Heterogeneous distributed Systems》https://arxiv.org/abs/1603.04467v1 。

分布式模式。

数据并行。https://www.tensorflow.org/tutorials/deep_cnn 。CPU负责梯度平均、参数更新,不同GPU训练模型副本(model replica)。基于训练样例子集训练,模型有独立性。 步骤:不同GPU分别定义模型网络结构。单个GPU从数据管道读取不同数据块,前向传播,计算损失,计算当前变量梯度。所有GPU输出梯度数据转移到CPU,梯度求平均操作,模型变量更新。重复,直到模型变量收敛。 数据并行,提高SGD效率。SGD mini-batch样本,切成多份,模型复制多份,在多个模型上同时计算。多个模型计算速度不一致,CPU更新变量有同步、异步两个方案。

同步更新、异步更新。分布式随机梯度下降法,模型参数分布式存储在不同参数服务上,工作节点并行训练数据,和参数服务器通信获取模型参数。 同步随机梯度下降法(Sync-SGD,同步更新、同步训练),训练时,每个节点上工作任务读入共享参数,执行并行梯度计算,同步需要等待所有工作节点把局部梯度处好,将所有共享参数合并、累加,再一次性更新到模型参数,下一批次,所有工作节点用模型更新后参数训练。优势,每个训练批次考虑所有工作节点训练情部,损失下降稳定。劣势,性能瓶颈在最慢工作节点。异楹设备,工作节点性能不同,劣势明显。 异步随机梯度下降法(Async-SGD,异步更新、异步训练),每个工作节点任务独立计算局部梯度,异步更新到模型参数,不需执行协调、等待操作。优势,性能不存在瓶颈。劣势,每个工作节点计算梯度值发磅回参数服务器有参数更新冲突,影响算法收剑速度,损失下降过程抖动较大。 同步更新、异步更新实现区别于更新参数服务器参数策略。数据量小,各节点计算能力较均衡,用同步模型。数据量大,各机器计算性能参差不齐,用异步模式。 带备份的Sync-SGD(Sync-SDG with backup)。Jianmin Chen、Xinghao Pan、Rajat Monga、Aamy Bengio、Rafal Jozefowicz论文《Revisiting Distributed Synchronous SGD》https://arxiv.org/abs/1604.00981 。增加工作节点,解决部分工作节点计算慢问题。工作节点总数n+n*5%,n为集群工作节点数。异步更新设定接受到n个工作节点参数直接更新参数服务器模型参数,进入下一批次模型训练。计算较慢节点训练参数直接丢弃。 同步更新、异步更新有图内模式(in-graph pattern)和图间模式(between-graph pattern),独立于图内(in-graph)、图间(between-graph)概念。 图内复制(in-grasph replication),所有操作(operation)在同一个图中,用一个客户端来生成图,把所有操作分配到集群所有参数服务器和工作节点上。国内复制和单机多卡类似,扩展到多机多卡,数据分发还是在客户端一个节点上。优势,计算节点只需要调用join()函数等待任务,客户端随时提交数据就可以训练。劣势,训练数据分发在一个节点上,要分发给不同工作节点,严重影响并发训练速度。 图间复制(between-graph replication),每一个工作节点创建一个图,训练参数保存在参数服务器,数据不分发,各个工作节点独立计算,计算完成把要更新参数告诉参数服务器,参数服务器更新参数。优势,不需要数据分发,各个工作节点都创建图和读取数据训练。劣势,工作节点既是图创建者又是计算任务执行者,某个工作节点宕机影响集群工作。大数据相关深度学习推荐使用图间模式。

模型并行。切分模型,模型不同部分执行在不同设备上,一个批次样本可以在不同设备同时执行。TensorFlow尽量让相邻计算在同一台设备上完成节省网络开销。Martin Abadi、Ashish Agarwal、Paul Barham论文《TensorFlow:Large-Scale Machine Learning on Heterogeneous Distributed Systems》https://arxiv.org/abs/1603.04467v1 。

模型并行、数据并行,TensorFlow中,计算可以分离,参数可以分离。可以在每个设备上分配计算节点,让对应参数也在该设备上,计算参数放一起。

分布式API。https://www.tensorflow.org/deploy/distributed 。 创建集群,每个任务(task)启动一个服务(工作节点服务或主节点服务)。任务可以分布不同机器,可以同一台机器启动多个任务,用不同GPU运行。每个任务完成工作:创建一个tf.train.ClusterSpec,对集群所有任务进行描述,描述内容对所有任务相同。创建一个tf.train.Server,创建一个服务,运行相应作业计算任务。 TensorFlow分布式开发API。tf.train.ClusterSpec({"ps":pshosts,"worker":workehosts})。创建TensorFlow集群描述信息,ps、worker为作业名称,psphsts、workerhosts为作业任务所在节点地址信息。tf.train.ClusterSpec传入参数,作业和任务间关系映射,映射关系任务通过IP地址、端口号表示。

结构 tf.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})
可用任务 /job:local/task:0、/job:local/task:1。
结构 tf.train.ClusterSpec({"worker":["worker0.example.com:2222","worker1.example.com:2222","worker2.example.com:2222"],"ps":["ps0.example.com:2222","ps1.example.com:2222"]})
可用任务 /job:worker/task:0、 /job:worker/task:1、 /job:worker/task:2、 /job:ps/task:0、 /job:ps/task:1

tf.train.Server(cluster,jobname,taskindex)。创建服务(主节点服务或工作节点服务),运行作业计算任务,运行任务在task_index指定机器启动。

#任务0 
cluster = tr.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})
server  = tr.train.Server(cluster,job_name="local",task_index=0) 
#任务1 
cluster = tr.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})
server  = tr.train.Server(cluster,job_name="local",task_index=1)。

自动化管理节点、监控节点工具。集群管理工具Kubernetes。 tf.device(devicenameor_function)。设定指定设备执行张量运算,批定代码运行CPU、GPU。

#指定在task0所在机器执行Tensor操作运算 
with tf.device("/job:ps/task:0"):
  weights_1 = tf.Variable(…)
  biases_1 = tf.Variable(…)

分布式训练代码框架。创建TensorFlow服务器集群,在该集群分布式计算数据流图。https://github.com/tensorflow/tensorflow/blob/master/tensorflow/docs_src/deploy/distributed.md 。

import argparse
import sys
import tensorflow as tf
FLAGS = None
def main(_):
  # 第1步:命令行参数解析,获取集群信息ps_hosts、worker_hosts
  # 当前节点角色信息job_name、task_index
  ps_hosts = FLAGS.ps_hosts.split(",")
  worker_hosts = FLAGS.worker_hosts.split(",")
  # 第2步:创建当前任务节点服务器
  # Create a cluster from the parameter server and worker hosts.
  cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
  # Create and start a server for the local task.
  server = tf.train.Server(cluster,
                           job_name=FLAGS.job_name,
                           task_index=FLAGS.task_index)
  # 第3步:如果当前节点是参数服务器,调用server.join()无休止等待;如果是工作节点,执行第4步
  if FLAGS.job_name == "ps":
    server.join()
  # 第4步:构建要训练模型,构建计算图
  elif FLAGS.job_name == "worker":
    # Assigns ops to the local worker by default.
    with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % FLAGS.task_index,
        cluster=cluster)):
      # Build model...
      loss = ...
      global_step = tf.contrib.framework.get_or_create_global_step()
      train_op = tf.train.AdagradOptimizer(0.01).minimize(
          loss, global_step=global_step)
    # The StopAtStepHook handles stopping after running given steps.
    # 第5步管理模型训练过程
    hooks=[tf.train.StopAtStepHook(last_step=1000000)]
    # The MonitoredTrainingSession takes care of session initialization,
    # restoring from a checkpoint, saving to a checkpoint, and closing when done
    # or an error occurs.
    with tf.train.MonitoredTrainingSession(master=server.target,
                                           is_chief=(FLAGS.task_index == 0),
                                           checkpoint_dir="/tmp/train_logs",
                                           hooks=hooks) as mon_sess:
      while not mon_sess.should_stop():
        # Run a training step asynchronously.
        # See `tf.train.SyncReplicasOptimizer` for additional details on how to
        # perform *synchronous* training.
        # mon_sess.run handles AbortedError in case of preempted PS.
        # 训练模型
        mon_sess.run(train_op)
if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.register("type", "bool", lambda v: v.lower() == "true")
  # Flags for defining the tf.train.ClusterSpec
  parser.add_argument(
      "--ps_hosts",
      type=str,
      default="",
      help="Comma-separated list of hostname:port pairs"
  )
  parser.add_argument(
      "--worker_hosts",
      type=str,
      default="",
      help="Comma-separated list of hostname:port pairs"
  )
  parser.add_argument(
      "--job_name",
      type=str,
      default="",
      help="One of 'ps', 'worker'"
  )
  # Flags for defining the tf.train.Server
  parser.add_argument(
      "--task_index",
      type=int,
      default=0,
      help="Index of task within the job"
  )
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

分布式最佳实践。https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/disttest/python/mnistreplica.py 。 MNIST数据集分布式训练。开设3个端口作分布式工作节点部署,2222端口参数服务器,2223端口工作节点0,2224端口工作节点1。参数服务器执行参数更新任务,工作节点0、工作节点1执行图模型训练计算任务。参数服务器/job:ps/task:0 cocalhost:2222,工作节点/job:worker/task:0 cocalhost:2223,工作节点/job:worker/task:1 cocalhost:2224。 运行代码。

python mnist_replica.py --job_name="ps" --task_index=0
python mnist_replica.py --job_name="worker" --task_index=0
python mnist_replica.py --job_name="worker" --task_index=1

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import sys
import tempfile
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# 定义常量,用于创建数据流图
flags = tf.app.flags
flags.DEFINE_string("data_dir", "/tmp/mnist-data",
                    "Directory for storing mnist data")
# 只下载数据,不做其他操作
flags.DEFINE_boolean("download_only", False,
                     "Only perform downloading of data; Do not proceed to "
                     "session preparation, model definition or training")
# task_index从0开始。0代表用来初始化变量的第一个任务
flags.DEFINE_integer("task_index", None,
                     "Worker task index, should be >= 0. task_index=0 is "
                     "the master worker task the performs the variable "
                     "initialization ")
# 每台机器GPU个数,机器没有GPU为0
flags.DEFINE_integer("num_gpus", 1,
                     "Total number of gpus for each machine."
                     "If you don't use GPU, please set it to '0'")
# 同步训练模型下,设置收集工作节点数量。默认工作节点总数
flags.DEFINE_integer("replicas_to_aggregate", None,
                     "Number of replicas to aggregate before parameter update"
                     "is applied (For sync_replicas mode only; default: "
                     "num_workers)")
flags.DEFINE_integer("hidden_units", 100,
                     "Number of units in the hidden layer of the NN")
# 训练次数
flags.DEFINE_integer("train_steps", 200,
                     "Number of (global) training steps to perform")
flags.DEFINE_integer("batch_size", 100, "Training batch size")
flags.DEFINE_float("learning_rate", 0.01, "Learning rate")
# 使用同步训练、异步训练
flags.DEFINE_boolean("sync_replicas", False,
                     "Use the sync_replicas (synchronized replicas) mode, "
                     "wherein the parameter updates from workers are aggregated "
                     "before applied to avoid stale gradients")
# 如果服务器已经存在,采用gRPC协议通信;如果不存在,采用进程间通信
flags.DEFINE_boolean(
    "existing_servers", False, "Whether servers already exists. If True, "
    "will use the worker hosts via their GRPC URLs (one client process "
    "per worker host). Otherwise, will create an in-process TensorFlow "
    "server.")
# 参数服务器主机
flags.DEFINE_string("ps_hosts","localhost:2222",
                    "Comma-separated list of hostname:port pairs")
# 工作节点主机
flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224",
                    "Comma-separated list of hostname:port pairs")
# 本作业是工作节点还是参数服务器
flags.DEFINE_string("job_name", None,"job name: worker or ps")
FLAGS = flags.FLAGS
IMAGE_PIXELS = 28
def main(unused_argv):
  mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
  if FLAGS.download_only:
    sys.exit(0)
  if FLAGS.job_name is None or FLAGS.job_name == "":
    raise ValueError("Must specify an explicit `job_name`")
  if FLAGS.task_index is None or FLAGS.task_index =="":
    raise ValueError("Must specify an explicit `task_index`")
  print("job name = %s" % FLAGS.job_name)
  print("task index = %d" % FLAGS.task_index)
  #Construct the cluster and start the server
  # 读取集群描述信息
  ps_spec = FLAGS.ps_hosts.split(",")
  worker_spec = FLAGS.worker_hosts.split(",")
  # Get the number of workers.
  num_workers = len(worker_spec)
  # 创建TensorFlow集群描述对象
  cluster = tf.train.ClusterSpec({
      "ps": ps_spec,
      "worker": worker_spec})
  # 为本地执行任务创建TensorFlow Server对象。
  if not FLAGS.existing_servers:
    # Not using existing servers. Create an in-process server.
    # 创建本地Sever对象,从tf.train.Server这个定义开始,每个节点开始不同
    # 根据执行的命令的参数(作业名字)不同,决定这个任务是哪个任务
    # 如果作业名字是ps,进程就加入这里,作为参数更新的服务,等待其他工作节点给它提交参数更新的数据
    # 如果作业名字是worker,就执行后面的计算任务
    server = tf.train.Server(
        cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
    # 如果是参数服务器,直接启动即可。这里,进程就会阻塞在这里
    # 下面的tf.train.replica_device_setter代码会将参数批定给ps_server保管
    if FLAGS.job_name == "ps":
      server.join()
  # 处理工作节点
  # 找出worker的主节点,即task_index为0的点
  is_chief = (FLAGS.task_index == 0)
  # 如果使用gpu
  if FLAGS.num_gpus > 0:
    # Avoid gpu allocation conflict: now allocate task_num -> #gpu
    # for each worker in the corresponding machine
    gpu = (FLAGS.task_index % FLAGS.num_gpus)
    # 分配worker到指定gpu上运行
    worker_device = "/job:worker/task:%d/gpu:%d" % (FLAGS.task_index, gpu)
  # 如果使用cpu
  elif FLAGS.num_gpus == 0:
    # Just allocate the CPU to worker server
    # 把cpu分配给worker
    cpu = 0
    worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu)
  # The device setter will automatically place Variables ops on separate
  # parameter servers (ps). The non-Variable ops will be placed on the workers.
  # The ps use CPU and workers use corresponding GPU
  # 用tf.train.replica_device_setter将涉及变量操作分配到参数服务器上,使用CPU。将涉及非变量操作分配到工作节点上,使用上一步worker_device值。
  # 在这个with语句之下定义的参数,会自动分配到参数服务器上去定义。如果有多个参数服务器,就轮流循环分配
  with tf.device(
      tf.train.replica_device_setter(
          worker_device=worker_device,
          ps_device="/job:ps/cpu:0",
          cluster=cluster)):

    # 定义全局步长,默认值为0
    global_step = tf.Variable(0, name="global_step", trainable=False)
    # Variables of the hidden layer
    # 定义隐藏层参数变量,这里是全连接神经网络隐藏层
    hid_w = tf.Variable(
        tf.truncated_normal(
            [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
            stddev=1.0 / IMAGE_PIXELS),
        name="hid_w")
    hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b")
    # Variables of the softmax layer
    # 定义Softmax 回归层参数变量
    sm_w = tf.Variable(
        tf.truncated_normal(
            [FLAGS.hidden_units, 10],
            stddev=1.0 / math.sqrt(FLAGS.hidden_units)),
        name="sm_w")
    sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
    # Ops: located on the worker specified with FLAGS.task_index
    # 定义模型输入数据变量
    x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])
    y_ = tf.placeholder(tf.float32, [None, 10])
    # 构建隐藏层
    hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
    hid = tf.nn.relu(hid_lin)
    # 构建损失函数和优化器
    y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
    cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
    # 异步训练模式:自己计算完成梯度就去更新参数,不同副本之间不会去协调进度
    opt = tf.train.AdamOptimizer(FLAGS.learning_rate)
    # 同步训练模式
    if FLAGS.sync_replicas:
      if FLAGS.replicas_to_aggregate is None:
        replicas_to_aggregate = num_workers
      else:
        replicas_to_aggregate = FLAGS.replicas_to_aggregate
      # 使用SyncReplicasOptimizer作优化器,并且是在图间复制情况下
      # 在图内复制情况下将所有梯度平均
      opt = tf.train.SyncReplicasOptimizer(
          opt,
          replicas_to_aggregate=replicas_to_aggregate,
          total_num_replicas=num_workers,
          name="mnist_sync_replicas")
    train_step = opt.minimize(cross_entropy, global_step=global_step)
    if FLAGS.sync_replicas:
      local_init_op = opt.local_step_init_op
      if is_chief:
        # 所有进行计算工作节点里一个主工作节点(chief)
        # 主节点负责初始化参数、模型保存、概要保存
        local_init_op = opt.chief_init_op
      ready_for_local_init_op = opt.ready_for_local_init_op
      # Initial token and chief queue runners required by the sync_replicas mode
      # 同步训练模式所需初始令牌、主队列
      chief_queue_runner = opt.get_chief_queue_runner()
      sync_init_op = opt.get_init_tokens_op()
    init_op = tf.global_variables_initializer()
    train_dir = tempfile.mkdtemp()
    if FLAGS.sync_replicas:
      # 创建一个监管程序,用于统计训练模型过程中的信息
      # lodger 是保存和加载模型路径
      # 启动就会去这个logdir目录看是否有检查点文件,有的话就自动加载
      # 没有就用init_op指定初始化参数
      # 主工作节点(chief)负责模型参数初始化工作
      # 过程中,其他工作节点等待主节眯完成初始化工作,初始化完成后,一起开始训练数据
      # global_step值是所有计算节点共享的
      # 在执行损失函数最小值时自动加1,通过global_step知道所有计算节点一共计算多少步
      sv = tf.train.Supervisor(
          is_chief=is_chief,
          logdir=train_dir,
          init_op=init_op,
          local_init_op=local_init_op,
          ready_for_local_init_op=ready_for_local_init_op,
          recovery_wait_secs=1,
          global_step=global_step)
    else:
      sv = tf.train.Supervisor(
          is_chief=is_chief,
          logdir=train_dir,
          init_op=init_op,
          recovery_wait_secs=1,
          global_step=global_step)
    # 创建会话,设置属性allow_soft_placement为True
    # 所有操作默认使用被指定设置,如GPU
    # 如果该操作函数没有GPU实现,自动使用CPU设备
    sess_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
        device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index])
    # The chief worker (task_index==0) session will prepare the session,
    # while the remaining workers will wait for the preparation to complete.
    # 主工作节点(chief),task_index为0节点初始化会话
    # 其余工作节点等待会话被初始化后进行计算
    if is_chief:
      print("Worker %d: Initializing session..." % FLAGS.task_index)
    else:
      print("Worker %d: Waiting for session to be initialized..." %
            FLAGS.task_index)
    if FLAGS.existing_servers:
      server_grpc_url = "grpc://" + worker_spec[FLAGS.task_index]
      print("Using existing server at: %s" % server_grpc_url)
      # 创建TensorFlow会话对象,用于执行TensorFlow图计算
      # prepare_or_wait_for_session需要参数初始化完成且主节点准备好后,才开始训练
      sess = sv.prepare_or_wait_for_session(server_grpc_url,
                                            config=sess_config)
    else:
      sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)
    print("Worker %d: Session initialization complete." % FLAGS.task_index)
    if FLAGS.sync_replicas and is_chief:
      # Chief worker will start the chief queue runner and call the init op.
      sess.run(sync_init_op)
      sv.start_queue_runners(sess, [chief_queue_runner])
    # Perform training
    # 执行分布式模型训练
    time_begin = time.time()
    print("Training begins @ %f" % time_begin)
    local_step = 0
    while True:
      # Training feed
      # 读入MNIST训练数据,默认每批次100张图片
      batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
      train_feed = {x: batch_xs, y_: batch_ys}
      _, step = sess.run([train_step, global_step], feed_dict=train_feed)
      local_step += 1
      now = time.time()
      print("%f: Worker %d: training step %d done (global step: %d)" %
            (now, FLAGS.task_index, local_step, step))
      if step >= FLAGS.train_steps:
        break
    time_end = time.time()
    print("Training ends @ %f" % time_end)
    training_time = time_end - time_begin
    print("Training elapsed time: %f s" % training_time)
    # Validation feed
    # 读入MNIST验证数据,计算验证的交叉熵
    val_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
    val_xent = sess.run(cross_entropy, feed_dict=val_feed)
    print("After %d training step(s), validation cross entropy = %g" %
          (FLAGS.train_steps, val_xent))
if __name__ == "__main__":
  tf.app.run()

参考资料: 《TensorFlow技术解析与实战》

欢迎推荐上海机器学习工作机会,我的微信:qingxingfengzi

暂无回复
登录 或者 注册