다음은 하나의 매개 변수 서버에 저장되고 각 작업자가 비동기 적으로 증가시키는 전역 카운터를 구현하는 분산 Tensorflow 코드의 두 가지 버전입니다.분산 된 Tensorflow는 이러한 종류의 tf를 처리합니다. 가변 생성?
두 버전 모두 같은 것을 인쇄하는 것으로 보이지만이 이유는 알 수 없습니다. 버전 간의 차이점은 주석으로 표시된 두 줄에 있습니다 (# NEW
).
각 작업자가 버전 1을 실행하면 매개 변수 서버가 각 작업자에 대해 local_counter
tf.Variable
을 자동으로 저장합니까?
버전 2에서는 각 매개 변수 서버에 local_counter
tf.Variable
을 명시 적으로 넣으려고합니다.
다음 버전 1 또는 버전 2는 실제로 차이가 있습니까?
PS : 모든 인스턴스에서 tf.Variable
을 (를) 관리하는 가장 좋은 방법이 아니므로 개선에 대한 조언을 드리고자합니다. 감사!
버전 1
# Standard distributed Tensorflow boilerplate
# ...
elif FLAGS.job_name == 'worker':
TASK = FLAGS.task_index
with tf.device('/job:ps/task:0/cpu:0'):
with tf.variable_scope('global'):
global_counter = tf.Variable(0, name='global_counter',
trainable=False)
local_counter = tf.Variable(0, name='local_counter_{}'.format(TASK),
trainable=False)
init_op = tf.global_variables_initializer()
with tf.device('/job:worker/task:{}'.format(TASK)):
with tf.variable_scope('local'):
local_inc_op = local_counter.assign_add(1)
global_inc_op = global_counter.assign_add(1)
with tf.Session(server.target):
sess.run(init_op)
global_count = 0
while global_count < 1000:
sess.run([local_inc_op, global_inc_op])
local_count, global_count = sess.run([local_counter, global_counter])
print('Local {}, Global {}, worker-{}'.format(
local_count, global_count, TASK))
버전 2
# Standard distributed Tensorflow boilerplate
# ...
elif FLAGS.job_name == 'worker':
NUM_WORKERS = len(worker_hosts)
TASK = FLAGS.task_index
with tf.device('/job:ps/task:0/cpu:0'):
with tf.variable_scope('global'):
global_counter = tf.Variable(0, name='global_counter',
trainable=False)
local_counters = [tf.Variable(0, name='local_counter_{}'.format(i),
trainable=False)
for i in range(NUM_WORKERS)] # NEW
init_op = tf.global_variables_initializer()
with tf.device('/job:worker/task:{}'.format(TASK)):
with tf.variable_scope('local'):
local_counter = local_counters[TASK] # NEW
local_inc_op = local_counter.assign_add(1)
global_inc_op = global_counter.assign_add(1)
with tf.Session(server.target):
sess.run(init_op)
global_count = 0
while global_count < 1000:
sess.run([local_inc_op, global_inc_op])
local_count, global_count = sess.run([local_counter, global_counter])
print('Local {}, Global {}, worker-{}'.format(
local_count, global_count, TASK))
Gotcha, 차이점은 근로자의 그래프입니다. 그렇지 않으면 기본적으로 기능이 동일합니다. 감사합니다, Allen! ResourceMgr을 살펴 보겠습니다. – awalllllll