Distributed TensorFlow Source Code Walkthrough 2: MonitoredTrainingSession

MonitoredTrainingSession is a widely used API for managing distributed training in TensorFlow. It bundles a number of training-monitoring components, such as variable initialization, resuming training from an existing checkpoint, and saving summaries, logs, and checkpoints. In earlier TF versions, tf.train.Supervisor was typically used to manage the session; after the framework was upgraded, the official recommendation became MonitoredTrainingSession. MonitoredTrainingSession provides logging, training visualization, checkpoint saving, early stopping, and training-efficiency tuning.

Let's get straight to the point. Below is the MonitoredTrainingSession source code. From the comments, its job can be summed up in one sentence: if this is the chief node, it is responsible for initializing the session or restoring it from an existing checkpoint, and it creates the hooks used to save checkpoints and summaries. If it is a non-chief worker node, it has to wait for the chief to finish initializing or restoring the session before it can set up its own session.

@tf_export(v1=['train.MonitoredTrainingSession'])
def MonitoredTrainingSession(
    master='',  # pylint: disable=invalid-name
    is_chief=True,
    checkpoint_dir=None,
    scaffold=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=USE_DEFAULT,
    save_summaries_steps=USE_DEFAULT,
    save_summaries_secs=USE_DEFAULT,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=USE_DEFAULT,
    summary_dir=None):

  """
  Creates a `MonitoredSession` for training.
  Returns:
    A `MonitoredSession` object.
  """
  
  scaffold = scaffold or Scaffold()
  worker_context = distribute_coordinator_context.get_current_worker_context()

  if worker_context:
    return _create_monitored_session_with_worker_context(
        worker_context,
        scaffold,
        checkpoint_dir=checkpoint_dir,
        hooks=hooks,
        chief_only_hooks=chief_only_hooks,
        save_checkpoint_secs=save_checkpoint_secs,
        save_summaries_steps=save_summaries_steps,
        save_summaries_secs=save_summaries_secs,
        config=config,
        stop_grace_period_secs=stop_grace_period_secs,
        log_step_count_steps=log_step_count_steps,
        max_wait_secs=max_wait_secs,
        save_checkpoint_steps=save_checkpoint_steps,
        summary_dir=summary_dir)

  if not is_chief:
    session_creator = WorkerSessionCreator(
        scaffold=scaffold,
        master=master,
        config=config,
        max_wait_secs=max_wait_secs)
    return MonitoredSession(
        session_creator=session_creator,
        hooks=hooks or [],
        stop_grace_period_secs=stop_grace_period_secs)

  all_hooks = []
  # The original source appends chief_only_hooks and the default
  # summary/checkpoint/step-counter saver hooks to all_hooks here (omitted for brevity).
  if chief_only_hooks:
    all_hooks.extend(chief_only_hooks)

  # The chief creates its session through a ChiefSessionCreator.
  session_creator = ChiefSessionCreator(
      scaffold=scaffold,
      checkpoint_dir=checkpoint_dir,
      master=master,
      config=config)

  if hooks:
    all_hooks.extend(hooks)

  return MonitoredSession(
      session_creator=session_creator,
      hooks=all_hooks,
      stop_grace_period_secs=stop_grace_period_secs)

Let's first go over the parameters:

is_chief: used in a distributed setting to indicate whether this task is the chief. If True, it is responsible for initializing and restoring the underlying TensorFlow session. If False, it waits for the chief to initialize or restore the TensorFlow session.

checkpoint_dir: a string. The path of a directory used to restore variables from an existing checkpoint.

scaffold: a Scaffold used to collect or build supporting ops. If not specified, a default Scaffold is created. It is used to finalize the creation of the graph.

hooks: an optional list of SessionRunHook objects. You can define your own SessionRunHook objects or use predefined ones, e.g. tf.train.StopAtStepHook() to set a stopping condition for training, or tf.train.NanTensorHook(loss) to stop training if the loss becomes NaN.

chief_only_hooks: a list of SessionRunHook objects. These hooks are activated only if is_chief == True; otherwise they are ignored.

save_checkpoint_secs: how often, in seconds, checkpoints are saved with the default checkpoint saver. If save_checkpoint_secs is set to None, checkpoints are not saved.

save_summaries_steps: how often, in global steps, summaries are written to disk with the default summary saver. If both save_summaries_steps and save_summaries_secs are set to None, the default summary saver is not used. Defaults to 100.

save_summaries_secs: how often, in seconds, summaries are written to disk with the default summary saver. If both save_summaries_steps and save_summaries_secs are set to None, the default summary saving is not used. Not enabled by default.

config: an instance of tf.ConfigProto used to configure the session. It is the config argument of the tf.Session constructor.

stop_grace_period_secs: number of seconds given to threads to stop after close() has been called.

log_step_count_steps: the frequency, in global steps, at which the number of global steps per second is logged.

Instantiating it gives you a MonitoredSession object, which can be used like an ordinary session.
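
To make the parameters concrete, here is a minimal sketch of a typical call; the checkpoint directory, step limit, and save frequencies below are illustrative values of my own, not defaults from the TensorFlow source:

import tensorflow as tf

# A trivial graph: just the global step plus an op that increments it.
global_step = tf.train.get_or_create_global_step()
train_op = tf.assign_add(global_step, 1)

hooks = [
    tf.train.StopAtStepHook(last_step=1000),   # stop after 1000 global steps
]

with tf.train.MonitoredTrainingSession(
        master='',                        # in-process session; use server.target in a cluster
        is_chief=True,
        checkpoint_dir='/tmp/mts_demo',   # where checkpoints and summaries are written
        hooks=hooks,
        save_checkpoint_secs=60,          # write a checkpoint every 60 seconds
        save_summaries_steps=100,         # write summaries every 100 global steps
        log_step_count_steps=100) as sess:
    while not sess.should_stop():
        sess.run(train_op)                # used just like an ordinary session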

Now let's break the code down step by step:

def _create_monitored_session_with_worker_context(
    worker_context,  # pylint: disable=missing-docstring
    scaffold,
    checkpoint_dir=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=None,
    save_summaries_steps=None,
    save_summaries_secs=None,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=None,
    summary_dir=None):
  all_hooks = []

  # The original source appends the various hooks to the all_hooks list here (omitted for brevity).

  logging.info('all_hooks %r', all_hooks)
  # Create the session creator from the worker context
  session_creator = worker_context.session_creator(
      scaffold,
      config=config,
      checkpoint_dir=checkpoint_dir,
      max_wait_secs=max_wait_secs)

  return MonitoredSession(
      session_creator=session_creator,
      hooks=all_hooks,
      stop_grace_period_secs=stop_grace_period_secs)

  # For reference, the body of the session_creator method defined on the worker context:
  def session_creator(self,
                      scaffold=None,
                      config=None,
                      checkpoint_dir=None,
                      checkpoint_filename_with_path=None,
                      max_wait_secs=7200):
    """
    Returns the appropriate session creator based on the correct master target and the session config.
    """
    if config:
      session_config = copy.deepcopy(config)
      session_config.MergeFrom(self._session_config)
    else:
      session_config = self._session_config
    
    # Create the session creator according to the task's role
    if not self._strategy or self._strategy.extended.experimental_should_init:
      logging.info("Creating chief session creator with config: %r", config)
      return monitored_session.ChiefSessionCreator(
          scaffold,
          master=self.master_target,
          config=session_config,
          checkpoint_dir=checkpoint_dir,
          checkpoint_filename_with_path=checkpoint_filename_with_path)
    else:
      logging.info("Creating worker session creator with config: %r", config)
      return monitored_session.WorkerSessionCreator(
          scaffold,
          master=self.master_target,
          config=session_config,
          max_wait_secs=max_wait_secs)

# ChiefSessionCreator
@tf_export(v1=['train.ChiefSessionCreator'])
class ChiefSessionCreator(SessionCreator):
  """Creates a tf.compat.v1.Session for a chief."""
  def __init__(self,
               scaffold=None,
               master='',
               config=None,
               checkpoint_dir=None,
               checkpoint_filename_with_path=None):
    self._checkpoint_dir = checkpoint_dir
    self._checkpoint_filename_with_path = checkpoint_filename_with_path
    self._scaffold = scaffold or Scaffold()
    self._session_manager = None
    self._master = master
    self._config = config

  def _get_session_manager(self):
    if self._session_manager:
      return self._session_manager

    self._session_manager = sm.SessionManager(
        local_init_op=self._scaffold.local_init_op,
        ready_op=self._scaffold.ready_op,
        ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
        graph=ops.get_default_graph())
    return self._session_manager

  def create_session(self):
    self._scaffold.finalize()
    return self._get_session_manager().prepare_session(
        self._master,
        saver=self._scaffold.saver,
        checkpoint_dir=self._checkpoint_dir,
        checkpoint_filename_with_path=self._checkpoint_filename_with_path,
        config=self._config,
        init_op=self._scaffold.init_op,
        init_feed_dict=self._scaffold.init_feed_dict,
        init_fn=self._scaffold.init_fn)

# WorkerSessionCreator
@tf_export(v1=['train.WorkerSessionCreator'])
class WorkerSessionCreator(SessionCreator):
  """Creates a tf.compat.v1.Session for a worker."""
  def __init__(self,
               scaffold=None,
               master='',
               config=None,
               max_wait_secs=30 * 60):
    """Initializes a worker session creator.

    Args:
      max_wait_secs: Maximum time to wait for the session to become available.
    """
    self._scaffold = scaffold or Scaffold()
    self._session_manager = None
    self._master = master
    self._config = config
    self._max_wait_secs = max_wait_secs

  def _get_session_manager(self):
    if self._session_manager:
      return self._session_manager

    self._session_manager = sm.SessionManager(
        local_init_op=self._scaffold.local_init_op,
        ready_op=self._scaffold.ready_op,
        ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
        graph=ops.get_default_graph())
    return self._session_manager

  def create_session(self):
    self._scaffold.finalize()
    return self._get_session_manager().wait_for_session(
        self._master, config=self._config, max_wait_secs=self._max_wait_secs)

From the source code above we can see that MonitoredTrainingSession creates different kinds of sessions for different roles: the chief node's session is created by the ChiefSessionCreator class, while non-chief worker nodes use the WorkerSessionCreator class. The special part is that the worker's creation goes through wait_for_session(), which roughly means it has to wait until the chief has finished creating its session before creating its own. In both cases the session is actually created by a method of the SessionManager class, so let's analyze SessionManager in detail:

The official documentation gives a simple example for the SessionManager class, which makes things quite clear:

  # prepare_session() can initialize a model or restore it from a checkpoint; it takes an `init_op` and a `saver`.
  with tf.Graph().as_default():
    # add operations to the graph...
    # Create a SessionManager that will checkpoint the model in '/tmp/mydir'.
    sm = SessionManager()
    sess = sm.prepare_session(master, init_op, saver, checkpoint_dir)
    # Use the session to train the graph.
    while True:
      sess.run(<my_train_op>)
  
  # A second process can run the ops as follows. wait_for_session() means it waits for the
  # session above to be created before creating its own session.
  with tf.Graph().as_default():
    # ...add operations to the graph...
    # Create a SessionManager that will wait for the model to become ready.
    sm = SessionManager()
    sess = sm.wait_for_session(master)
    # Use the session to train the graph.
    while True:
      sess.run(<my_train_op>)

Now let's focus on the two functions prepare_session and wait_for_session:

@tf_export(v1=["train.SessionManager"])
class SessionManager(object):
  def __init__(self,
               local_init_op=None,
               ready_op=None,
               ready_for_local_init_op=None,
               graph=None,
               recovery_wait_secs=30,
               local_init_run_options=None):
    """
    local_init_op: an op that is run every time a new session is created.
    ready_op: an op used to check whether the model is ready.
    ready_for_local_init_op: an op used to check whether the model is ready to run local_init_op.
    """
    # Sets default values of arguments.
    if graph is None:
      graph = ops.get_default_graph()
    self._local_init_op = local_init_op
    self._ready_op = ready_op
    self._ready_for_local_init_op = ready_for_local_init_op
    self._graph = graph
    self._recovery_wait_secs = recovery_wait_secs
    self._target = None
    self._local_init_run_options = local_init_run_options
    if ready_for_local_init_op is not None and local_init_op is None:
      raise ValueError("If you pass a ready_for_local_init_op "
                       "you must also pass a local_init_op "
                       ", ready_for_local_init_op [%s]" %
                       ready_for_local_init_op)

  def prepare_session(self,
                      master,
                      init_op=None,
                      saver=None,
                      checkpoint_dir=None,
                      checkpoint_filename_with_path=None,
                      wait_for_checkpoint=False,
                      max_wait_secs=7200,
                      config=None,
                      init_feed_dict=None,
                      init_fn=None):
    """
    prepare_session() restores the session from a checkpoint if one exists; if there is no
    checkpoint, it initializes the session from the given `init_op` and/or by calling `init_fn`.
    """
    sess, is_loaded_from_checkpoint = self._restore_checkpoint(
        master,
        saver,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path,
        wait_for_checkpoint=wait_for_checkpoint,
        max_wait_secs=max_wait_secs,
        config=config)
    if not is_loaded_from_checkpoint:
      if init_op is None and not init_fn and self._local_init_op is None:
        raise RuntimeError("Model is not initialized and no init_op or "
                           "init_fn or local_init_op was given")
      if init_op is not None:
        sess.run(init_op, feed_dict=init_feed_dict)
      if init_fn:
        init_fn(sess)
    # ... (rest of the function omitted) ...
    return sess


  def wait_for_session(self, master, config=None, max_wait_secs=float("Inf")):
    """
     Creates a new `Session` and waits for model to be ready.
    """
    self._target = master
    if max_wait_secs is None:
      max_wait_secs = float("Inf")
    timer = _CountDownTimer(max_wait_secs)

    while True:
      sess = session.Session(self._target, graph=self._graph, config=config)
      not_ready_msg = None
      not_ready_local_msg = None
      local_init_success, not_ready_local_msg = self._try_run_local_init_op(
          sess)
      if local_init_success:
        # Successful if local_init_op is None, or ready_for_local_init_op passes
        is_ready, not_ready_msg = self._model_ready(sess)
        if is_ready:
          return sess
      self._safe_close(sess)
      # Do we have enough time left to try again?
      remaining_ms_after_wait = (
          timer.secs_remaining() - self._recovery_wait_secs)
      if remaining_ms_after_wait < 0:
        raise errors.DeadlineExceededError(
            None, None,
            "Session was not ready after waiting %d secs." % (max_wait_secs,))
      logging.info("Waiting for model to be ready.  "
                   "Ready_for_local_init_op:  %s, ready: %s",
                   not_ready_local_msg, not_ready_msg)
      time.sleep(self._recovery_wait_secs)

Once the session has been created, it is wrapped up and returned as the final MonitoredSession object.

At creation time, a complete monitored session does the following, in order (a minimal custom-hook sketch illustrating these callbacks follows the three lists below):

  • calls hook.begin() for each given hook
  • finalizes the graph via scaffold.finalize()
  • creates the session
  • initializes the model parameters via the Scaffold
  • restores the parameters from a checkpoint if one exists
  • launches the queue runners
  • calls hook.after_create_session()

When run() is called, the monitored session:

  • calls hook.before_run()
  • calls the underlying TensorFlow `session.run()` with the merged fetches and feed_dict
  • calls hook.after_run()
  • returns the result of session.run()
  • if an AbortedError or UnavailableError occurs, recovers or reinitializes the session before executing run() again

When close() is called, the monitored session:

  • calls hook.end()
  • closes the queue runners and the session
  • suppresses the OutOfRange error that indicates all input data has been consumed
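
As a sketch of how the hook callbacks above fit together, here is a minimal custom SessionRunHook; the hook name, the logged quantity, and the logging interval are illustrative choices of mine, not part of the TensorFlow source discussed above:

import tensorflow as tf

class LossLoggerHook(tf.train.SessionRunHook):
    """Illustrative hook following the begin / before_run / after_run / end lifecycle."""

    def __init__(self, loss_op, every_n_steps=100):
        self._loss_op = loss_op
        self._every_n_steps = every_n_steps
        self._step = 0

    def begin(self):
        # Called once before the session is created; the graph can still be modified here.
        self._step = 0

    def before_run(self, run_context):
        # Ask the monitored session to additionally fetch the loss tensor in this run() call.
        return tf.train.SessionRunArgs(self._loss_op)

    def after_run(self, run_context, run_values):
        # run_values.results holds the fetched loss value.
        self._step += 1
        if self._step % self._every_n_steps == 0:
            tf.logging.info('step %d, loss %f', self._step, run_values.results)

    def end(self, session):
        # Called once when close() is invoked on the monitored session.
        tf.logging.info('training finished after %d run() calls', self._step)

A hook like this is passed in through the hooks argument of MonitoredTrainingSession, and the MonitoredSession drives these callbacks at exactly the points listed above.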

Finally, here is an example of using MonitoredTrainingSession for distributed training:

from __future__ import print_function, absolute_import, division

import tensorflow as tf

tf.app.flags.DEFINE_string("ps_hosts", "localhost:2222", "ps hosts")
tf.app.flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224", "worker hosts")
tf.app.flags.DEFINE_string("job_name", "worker", "'ps' or'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
tf.app.flags.DEFINE_integer("num_workers", 2, "Number of workers")
tf.app.flags.DEFINE_boolean("is_sync", False, "using synchronous training or not")

FLAGS = tf.app.flags.FLAGS


def model(images):
    """Define a simple mnist classifier"""
    net = tf.layers.dense(images, 500, activation=tf.nn.relu)
    net = tf.layers.dense(net, 500, activation=tf.nn.relu)
    net = tf.layers.dense(net, 10, activation=None)
    return net


(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32')
x_test = x_test.reshape(-1, 784).astype('float32')
x_train /= 255
x_test /= 255


def get_batch(image, label, batch_size=32, training=True):
    df = tf.data.Dataset.from_tensor_slices((image, label))
    if training:
        df = df.repeat(10).shuffle(buffer_size=1000)
    df = df.batch(batch_size).prefetch(batch_size)
    iterator = df.make_one_shot_iterator()
    batch_x, batch_y = iterator.get_next()
    return batch_x, batch_y


def main(_):
    ps_hosts = FLAGS.ps_hosts.split(",")
    worker_hosts = FLAGS.worker_hosts.split(",")

    # create the cluster configured by `ps_hosts' and 'worker_hosts'
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

    # create a server for local task
    server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)

    train_batch_x, train_batch_y = get_batch(x_train, y_train)
    test_batch_x, test_batch_y = get_batch(x_test, y_test, training=False)

    if FLAGS.job_name == "ps":
        server.join()  # ps hosts only join
    elif FLAGS.job_name == "worker":
        # workers perform the operation
        # ps_strategy = tf.contrib.training.GreedyLoadBalancingStrategy(FLAGS.num_ps)

        # Note: tf.train.replica_device_setter automatically places the parameters (Variables)
        # on the ps hosts (default placement strategy: round-robin over all ps hosts) and also
        # places a copy of the operations on each worker host
        with tf.device(tf.train.replica_device_setter(worker_device="/job:worker/task:%d" % FLAGS.task_index,
                                                      cluster=cluster)):

            logits = model(train_batch_x)
            loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=tf.one_hot(train_batch_y, 10)))

            # The StopAtStepHook handles stopping after running given steps.
            hooks = [tf.train.StopAtStepHook(last_step=10000)]

            global_step = tf.train.get_or_create_global_step()
            optimizer = tf.train.AdamOptimizer(learning_rate=1e-04)
            
            if FLAGS.is_sync:
                # synchronous training
                # wrap the optimizer with tf.train.SyncReplicasOptimizer
                # ref: https://www.tensorflow.org/api_docs/python/tf/train/SyncReplicasOptimizer
                optimizer = tf.train.SyncReplicasOptimizer(optimizer, replicas_to_aggregate=FLAGS.num_workers,
                                                           total_num_replicas=FLAGS.num_workers)
                # create the hook which handles initialization and queues
                hooks.append(optimizer.make_session_run_hook((FLAGS.task_index == 0)))

            train_op = optimizer.minimize(loss, global_step=global_step)

            # The MonitoredTrainingSession takes care of session initialization,
            # restoring from a checkpoint, saving to a checkpoint, and closing when done
            # or an error occurs.
            with tf.train.MonitoredTrainingSession(master=server.target,
                                                   is_chief=(FLAGS.task_index == 0),
                                                   checkpoint_dir="./checkpoint_dir",
                                                   hooks=hooks) as mon_sess:
                while not mon_sess.should_stop():
                    # mon_sess.run handles AbortedError in case of preempted PS.
                    _, ls, step = mon_sess.run([train_op, loss, global_step])
                    if step % 100 == 0:
                        print("Train step %d, loss: %f" % (step, ls))


if __name__ == "__main__":
    tf.app.run()
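
To run this example, start one process per entry in ps_hosts and worker_hosts with matching flags: with the default flag values above, that means one ps process with --job_name=ps --task_index=0 and two worker processes with --job_name=worker and --task_index set to 0 and 1 respectively. The worker with task_index 0 acts as the chief and writes the checkpoints to ./checkpoint_dir.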
