# Assumes the module-level setup from the rest of this file: `tf` is
# TensorFlow (1.x, graph mode), `nccl` is the NCCL ops module imported
# earlier (e.g. tensorflow.contrib.nccl), and `have_nccl` is the flag
# recording whether that import succeeded.
def all_sync_params(tower_params, devices, usenccl=True):
    """Assigns the params from the first tower to all others."""
    if len(devices) == 1:
        return tf.no_op()
    sync_ops = []
    if have_nccl and usenccl:
        for param_on_devices in zip(*tower_params):
            # print('PARAM_ON_DEVICES: {}'.format(param_on_devices))  # DEBUG
            # Note: param_on_devices is [paramX_gpu0, paramX_gpu1, ...]
            param0 = param_on_devices[0]
            send_op, received_tensors = nccl.broadcast(param0, devices[1:])
            sync_ops.append(send_op)
            for device, param, received in zip(devices[1:],
                                               param_on_devices[1:],
                                               received_tensors):
                with tf.device(device):
                    sync_op = param.assign(received)
                    sync_ops.append(sync_op)
    else:
        params0 = tower_params[0]
        for device, params in zip(devices, tower_params):
            with tf.device(device):
                for param, param0 in zip(params, params0):
                    sync_op = param.assign(param0.read_value())
                    sync_ops.append(sync_op)
    return tf.group(*sync_ops)
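

# --- Usage sketch for all_sync_params (illustrative only, not from the
# original file). It builds one copy of two variables per device, then runs
# the sync op so every tower starts from tower 0's values. The tower names,
# shapes, and usenccl=False (to exercise the non-NCCL fallback) are
# assumptions; it still relies on the module defining `have_nccl`.
if __name__ == '__main__':
    devices = ['/gpu:0', '/gpu:1']
    tower_params = []
    for i, dev in enumerate(devices):
        with tf.device(dev), tf.variable_scope('tower_%d' % i):
            w = tf.get_variable('w', shape=[128, 64])
            b = tf.get_variable('b', shape=[64])
            tower_params.append([w, b])

    # Copy tower 0's parameter values onto the replicas on the other devices.
    sync_op = all_sync_params(tower_params, devices, usenccl=False)

    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(sync_op)  # run once after init (and after any restore)
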
# def stage(tensors):
#     """Stages the given tensors in a StagingArea for asynchronous put/get.
#     """
#     stage_area = data_flow_ops.StagingArea(
#         dtypes=[tensor.dtype for tensor in tensors],
#         shapes=[tensor.get_shape() for tensor in tensors])
#     put_op = stage_area.put(tensors)
#     get_tensors = stage_area.get()
#     if not isinstance(get_tensors, list):
#         get_tensors = [get_tensors]
#     # print('GET_TENSORS: {}'.format(get_tensors))  # DEBUG
#
#     get_tensors = [tf.reshape(gt, t.get_shape())
#                    for (gt, t) in zip(get_tensors, tensors)]
#     return put_op, get_tensors
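

# --- Staging sketch (illustrative only, not from the original file). It
# shows the pattern the commented-out stage() helper above is meant for:
# put() enqueues one batch while get() hands the previously staged batch to
# the compute graph, overlapping input transfer with training. The function
# name and the images/labels arguments are hypothetical.
from tensorflow.python.ops import data_flow_ops


def _staging_sketch(images, labels):
    """Wires two input tensors through a StagingArea (illustration only)."""
    area = data_flow_ops.StagingArea(
        dtypes=[images.dtype, labels.dtype],
        shapes=[images.get_shape(), labels.get_shape()])
    put_op = area.put([images, labels])
    staged_images, staged_labels = area.get()
    # Typical loop: run put_op once to warm up the pipeline, then run
    # (train_op, put_op) together every step so the next batch is staged
    # while the current one is being consumed.
    return put_op, staged_images, staged_labels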