def stochastical_binarize_gradients(grads_and_vars, scalers):
    """Stochastically binarize gradients.

    Each gradient is paired, in order, with one entry of ``scalers``. For a
    gradient whose element count is at least ``FLAGS.size_to_binarize``, every
    element ``g`` is replaced by ``sign(g) * scaler`` when a sample drawn
    uniformly from ``[0, scaler)`` is less than ``|g|``, and by ``0``
    otherwise. Smaller gradients are passed through unchanged; the choice is
    made inside the graph with ``tf.cond``.

    Args:
        grads_and_vars: Iterable of ``(gradient, variable)`` pairs. A gradient
            may be ``None``, a ``tf.IndexedSlices``, or any object exposing
            ``get_shape()`` (presumably a ``tf.Tensor`` -- confirm against
            the caller).
        scalers: Iterable of per-gradient scale values, consumed in lockstep
            with ``grads_and_vars``.

    Returns:
        A list of ``(binarized_gradient, variable)`` tuples. Entries whose
        gradient is ``None`` keep ``None``. An empty ``grads_and_vars``
        yields an empty list. If ``scalers`` is shorter than
        ``grads_and_vars``, the result is truncated to the shorter length
        (plain ``zip`` semantics).
    """
    # Materialize first so one-shot iterators can be checked for emptiness.
    grads_and_vars = list(grads_and_vars)
    if not grads_and_vars:
        # zip(*[]) produces nothing, so the two-way unpack below would raise
        # ValueError; there is simply nothing to binarize.
        return []

    gradients, variables = zip(*grads_and_vars)
    binarized_gradients = []
    for gradient, scaler in zip(gradients, scalers):
        if gradient is None:
            # Keep the slot so the result stays aligned with `variables`.
            binarized_gradients.append(None)
            continue

        # IndexedSlices carry their full shape in `dense_shape`.
        if isinstance(gradient, tf.IndexedSlices):
            gradient_shape = gradient.dense_shape
        else:
            gradient_shape = gradient.get_shape()

        # These ops are created outside the tf.cond branch callables; only
        # the final selection happens inside the branch.
        zeros = tf.zeros(gradient_shape)
        abs_gradient = tf.abs(gradient)
        sign_gradient = tf.sign(gradient)
        rnd_sample = tf.random_uniform(gradient_shape, 0, scaler)
        # True where the element is kept (and snapped to +/- scaler).
        where_cond = tf.less(rnd_sample, abs_gradient)

        binarized_gradient = tf.cond(
            tf.size(gradient) < FLAGS.size_to_binarize,
            lambda: gradient,
            lambda: tf.where(where_cond, sign_gradient * scaler, zeros))
        binarized_gradients.append(binarized_gradient)

    return list(zip(binarized_gradients, variables))
评论列表
文章目录