def encode_to_ternary_gradients(grads_and_vars, get_shape=False):
    """Encode each gradient tensor, leaving the paired variables untouched.

    A gradient whose ``tf.size`` is below ``FLAGS.size_to_binarize`` is
    bitcast to ``tf.uint8``; any other gradient is passed through
    ``ternary_encoder`` (defined elsewhere in this module). ``None``
    gradients are kept as ``None``.

    Args:
        grads_and_vars: Iterable of ``(gradient, variable)`` pairs. May be
            empty, in which case empty results are returned and no graph
            ops are created.
        get_shape: When true, also return the shape recorded for each
            gradient.

    Returns:
        A list of ``(encoded_gradient, variable)`` pairs in input order.
        When ``get_shape`` is true, a tuple ``(pairs, shapes)`` where
        ``shapes`` is a parallel list holding ``dense_shape`` for
        ``tf.IndexedSlices`` gradients, ``get_shape()`` for other
        gradients, and ``None`` where the gradient is ``None``.
    """
    # Materialize once so one-shot iterables work and emptiness can be tested.
    grads_and_vars = list(grads_and_vars)
    if not grads_and_vars:
        # zip(*[]) yields nothing, so the two-way unpacking below would
        # raise ValueError; return the empty result explicitly instead.
        return ([], []) if get_shape else []

    with tf.name_scope('ternary_encoder'):
        gradients, variables = zip(*grads_and_vars)
        ternary_gradients = []
        gradient_shapes = []
        for gradient in gradients:
            if gradient is None:
                # Keep placeholders so outputs stay aligned with variables.
                ternary_gradients.append(None)
                if get_shape:
                    gradient_shapes.append(None)
                continue

            if get_shape:
                if isinstance(gradient, tf.IndexedSlices):
                    gradient_shape = gradient.dense_shape
                else:
                    gradient_shape = gradient.get_shape()
                gradient_shapes.append(gradient_shape)

            # Gradients below the size threshold skip ternary encoding and
            # are only reinterpreted as raw bytes.
            # NOTE(review): both tf.cond branches presumably yield uint8
            # tensors -- confirm against ternary_encoder.
            ternary_gradient = tf.cond(
                tf.size(gradient) < FLAGS.size_to_binarize,
                lambda: tf.bitcast(gradient, type=tf.uint8),
                lambda: ternary_encoder(gradient))
            ternary_gradients.append(ternary_gradient)

        encoded_pairs = list(zip(ternary_gradients, variables))
        if get_shape:
            return encoded_pairs, gradient_shapes
        return encoded_pairs
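# (scraped page navigation residue: "Comment list")
# (scraped page navigation residue: "Article table of contents")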