def collect_gradients(gradients, variables):
    """Build one op that adds each gradient into its matching variable.

    Intended for gradient accumulation: every gradient is added in place to
    the variable at the same position in ``variables``.

    Args:
        gradients: Sequence of gradients aligned element-wise with
            ``variables``. An entry may be a dense tensor (or any value
            ``tf.assign_add`` accepts), a ``tf.IndexedSlices`` sparse
            gradient, or ``None`` for a variable that received no gradient;
            ``None`` entries are skipped.
        variables: Sequence of variables to accumulate into. Pairing uses
            ``zip``, so surplus entries in the longer of the two sequences
            are silently ignored.

    Returns:
        A ``tf.group`` op that runs all of the accumulation updates.
    """
    ops = []
    for grad, var in zip(gradients, variables):
        if grad is None:
            # tf.gradients / compute_gradients yield None for variables the
            # loss does not depend on; there is nothing to accumulate, and
            # touching grad.indices would raise AttributeError.
            continue
        if isinstance(grad, tf.IndexedSlices):
            # Sparse gradient: add only into the rows listed in grad.indices.
            ops.append(tf.scatter_add(var, grad.indices, grad.values))
        else:
            # Dense gradient: add the whole value into the variable.
            ops.append(tf.assign_add(var, grad))
    return tf.group(*ops)
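# Web-page navigation residue left over from scraping (not code):
# "评论列表" (comment list), "文章目录" (table of contents).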