tensorflow中实现分布式的Barrier挺让人头疼

分布式training中,经常需要让所有的节点在某一个点同步一下。比如,

  1. 只有id=0的worker进行完所有初始化操作后,其它的worker才能开始运行training
  2. 只有所有的worker都train完后,才能开始dump model
那么怎么实现这样的同步功能呢?最naive的做法是,声明一个int类型、形状为scalar 的 tensorflow variable。这个variable的初始值为0,每个worker到达sync point的时候,就把这个tensor的值加1。然后不停的sleep && get value。如下面代码所示:
声明
finished = tf.get_variable("worker_finished",
   [],tf.int32,tf.zeros_initializer(tf.int32),trainable=False)                    
with finished.graph.colocate_with(finished):
 finish_op = finished.assign_add(1,use_locking=True)
使用:
worker_finished = sess.run(finish_op)
print('%d worker finished' % worker_finished)
if is_chief:
 try:
  while worker_finished < worker_count:               
   time.sleep(3) 
   worker_finished = raw_sess.run(finished)
   print('%d worker finished' % worker_finished)
 except Exception as ex:
  print('exit with error:%s' % str(ex))
  return

但是这样有个缺点:所有的variable默认都会被写入到checkpoint中,这个也不例外。所以,当下次从checkpoint中再载入时,就乱套了。能不能不让saver保存这个variable? 可以。但是如果session创建时是从checkpoint恢复,那么默认不会运行init_op。于是这个variable就会变成未初始化的状态。为了解决这个问题,接下来又会引发一连串的与并发相关的问题。

最后我自己实现了一个,功能和Variable很像的operator。它没有init_op,永远默认初始化为0。问题就全解决了。主要注意的是该resource的shared_name不能为空字符串。




此博客中的热门博文

少写代码,多读别人写的代码

在windows下使用llvm+clang

tensorflow distributed runtime初窥