def decode_from_ternary_gradients(grads_and_vars, scalers, shapes):
    """Decode each encoded gradient tensor back into a floating-point gradient.

    Args:
      grads_and_vars: Iterable of ``(gradient, variable)`` pairs. Each
        gradient is in encoded form, or ``None`` when the variable received
        no gradient.
      scalers: Per-gradient scaler values, aligned with ``grads_and_vars``;
        forwarded to ``ternary_decoder``.
      shapes: Per-gradient shapes, aligned with ``grads_and_vars``; forwarded
        to ``ternary_decoder``.

    Returns:
      A list of ``(decoded_gradient, variable)`` pairs in the input order.
      Pairs whose gradient was ``None`` keep ``None`` as the gradient. An
      empty input yields an empty list.
    """
    # Materialize so that generators work and the empty case can be detected;
    # `zip(*[])` would otherwise fail to unpack into two names.
    grads_and_vars = list(grads_and_vars)
    if not grads_and_vars:
        return []

    with tf.name_scope('ternary_decoder'):
        gradients, variables = zip(*grads_and_vars)
        floating_gradients = []
        for gradient, variable, scaler, shape in zip(gradients, variables, scalers, shapes):
            if gradient is None:
                # Nothing to decode. Skip the rest of the loop body: falling
                # through would try to decode `None` and could append a second
                # entry, misaligning gradients with variables.
                floating_gradients.append(None)
                continue
            # The gradient is encoded, so its own size does not reflect the
            # original tensor; use the variable's size to decide which
            # encoding was applied. Small tensors were only bitcast, so they
            # are bitcast back. This assumes the original gradient had the
            # same dtype as its variable.
            floating_gradient = tf.cond(
                tf.size(variable) < FLAGS.size_to_binarize,
                lambda: tf.bitcast(gradient, variable.dtype),
                lambda: ternary_decoder(gradient, scaler, shape))
            floating_gradients.append(floating_gradient)
        return list(zip(floating_gradients, variables))
评论列表
文章目录