bingrad_common.py source code

python

Project: terngrad  Author: wenwei202
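import tensorflow as tf

# Assumes an integer --size_to_binarize flag is defined elsewhere in the project.
FLAGS = tf.app.flags.FLAGS


def stochastical_binarize_gradients(grads_and_vars, scalers):
  """Stochastically binarize gradients (TensorFlow 1.x graph API).

  Each element g becomes scaler * sign(g) with probability |g| / scaler and
  0 otherwise, which is an unbiased estimate of g when scaler >= |g|.
  Tensors with fewer than FLAGS.size_to_binarize elements pass through as-is.
  """
  gradients, variables = zip(*grads_and_vars)
  binarized_gradients = []
  for gradient, scaler in zip(gradients, scalers):
    if gradient is None:
      binarized_gradients.append(None)
      continue
    # Densify tf.IndexedSlices (e.g. embedding gradients) so both tf.cond
    # branches below return a dense Tensor with the same structure.
    gradient = tf.convert_to_tensor(gradient)
    # Dynamic shape also works when the static shape is only partly known.
    gradient_shape = tf.shape(gradient)

    zeros = tf.zeros(gradient_shape, dtype=gradient.dtype)
    abs_gradient = tf.abs(gradient)
    sign_gradient = tf.sign(gradient)
    # Uniform samples in [0, scaler): P(sample < |g|) = |g| / scaler.
    rnd_sample = tf.random_uniform(gradient_shape, 0, scaler,
                                   dtype=gradient.dtype)
    where_cond = tf.less(rnd_sample, abs_gradient)
    binarized_gradient = tf.cond(
        tf.size(gradient) < FLAGS.size_to_binarize,
        lambda: gradient,  # small tensor: keep full precision
        lambda: tf.where(where_cond, sign_gradient * scaler, zeros))

    binarized_gradients.append(binarized_gradient)
  return list(zip(binarized_gradients, variables))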
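Despite the name, each output element takes one of three values, -scaler, 0 or +scaler: an element g becomes scaler * sign(g) with probability min(|g| / scaler, 1), so its expected value is exactly g whenever scaler >= |g|. The sketch below reproduces that per-element rule in plain Python on a flat list so it can be run without TensorFlow. The function name `stochastic_ternarize`, its `rng` and `size_to_binarize` parameters, and the use of the largest |g| as the scaler in the demo are illustrative assumptions, not part of the terngrad code.

```python
import random


def stochastic_ternarize(grad, scaler, rng=random, size_to_binarize=0):
    """Hypothetical pure-Python mirror of stochastical_binarize_gradients for one flat gradient.

    Each element g becomes scaler * sign(g) with probability min(|g| / scaler, 1),
    otherwise 0.0. Lists shorter than size_to_binarize are returned unchanged,
    like the small-tensor branch of the tf.cond.
    """
    if len(grad) < size_to_binarize:
        return list(grad)
    out = []
    for g in grad:
        sign = (g > 0) - (g < 0)        # -1, 0 or +1, like tf.sign
        u = rng.uniform(0.0, scaler)    # like tf.random_uniform(shape, 0, scaler)
        out.append(sign * scaler if u < abs(g) else 0.0)
    return out


if __name__ == "__main__":
    rng = random.Random(0)
    grad = [0.3, -0.1, 0.0, 0.05]
    scaler = max(abs(g) for g in grad)  # assumption: per-tensor max |g| as the scaler
    print(stochastic_ternarize(grad, scaler, rng))  # one noisy draw
    trials = 20000
    sums = [0.0] * len(grad)
    for _ in range(trials):
        for i, q in enumerate(stochastic_ternarize(grad, scaler, rng)):
            sums[i] += q
    print([round(s / trials, 3) for s in sums])  # averages approach grad
```

Averaging many draws recovers the original gradient, which is why the quantized values can replace full-precision ones in SGD. A single draw, however, has variance scaler * |g| - g^2 per element, so a larger scaler means noisier updates.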