bingrad_common.py source file

python

Project: terngrad  Author: wenwei202
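import tensorflow as tf

# FLAGS (which provides size_to_binarize) and ternary_decoder are assumed to be
# defined elsewhere in bingrad_common.py; neither is shown in this snippet.


def decode_from_ternary_gradients(grads_and_vars, scalers, shapes):
  """Decode each gradient tensor.

  Tensors with fewer than FLAGS.size_to_binarize elements were not ternarized,
  so their bits are reinterpreted as the variable's dtype with tf.bitcast;
  larger tensors are reconstructed by ternary_decoder from scaler and shape.
  """
  with tf.name_scope('ternary_decoder'):
    gradients, variables = zip(*grads_and_vars)
    floating_gradients = []
    for gradient, variable, scaler, shape in zip(gradients, variables, scalers, shapes):
      if gradient is None:
        floating_gradients.append(None)
        continue  # otherwise None is appended twice and then passed to tf.cond
      # The gradient is encoded, so we use the variable to check its size.
      # We also assume the variable and the gradient share the same dtype.
      floating_gradient = tf.cond(tf.size(variable) < FLAGS.size_to_binarize,
                                  lambda: tf.bitcast(gradient, variable.dtype),
                                  lambda: ternary_decoder(gradient, scaler, shape))
      floating_gradients.append(floating_gradient)

    return list(zip(floating_gradients, variables))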
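Because `ternary_decoder` and `FLAGS.size_to_binarize` are not shown here, the following is a hypothetical, framework-free sketch of the same dispatch on plain Python lists. It assumes a TernGrad-style scheme in which each element g is stochastically mapped to sign(g) with probability |g| / s (and to 0 otherwise), where s = max|g|, so that s times the ternary value is an unbiased estimate of g. The names `ternary_encode`, `ternary_decode`, `decode_all`, and `SIZE_TO_BINARIZE` are illustrative only; the real `ternary_decoder` may pack values into bytes and reshape with `shape`, which this sketch does not attempt.

```python
import random
from typing import List, Optional, Sequence, Tuple

# Hypothetical stand-in for FLAGS.size_to_binarize; the real value is not shown.
SIZE_TO_BINARIZE = 4


def ternary_encode(grad: Sequence[float], rng: random.Random) -> Tuple[List[int], float]:
  """Stochastically ternarize a flat gradient (assumed TernGrad-style scheme).

  Each element becomes sign(g) with probability |g| / s and 0 otherwise,
  where s = max|g|, so scaler * ternary is an unbiased estimate of grad.
  """
  scaler = max((abs(g) for g in grad), default=0.0)
  if scaler == 0.0:
    return [0] * len(grad), 0.0
  ternary = []
  for g in grad:
    keep = rng.random() < abs(g) / scaler  # Bernoulli(|g| / s)
    ternary.append((1 if g > 0 else -1) if keep else 0)
  return ternary, scaler


def ternary_decode(ternary: Sequence[int], scaler: float) -> List[float]:
  """Recover floats as scaler * {-1, 0, +1}."""
  return [scaler * t for t in ternary]


def decode_all(encoded, scalers, sizes,
               size_to_binarize: int = SIZE_TO_BINARIZE) -> List[Optional[List[float]]]:
  """Mirror the dispatch in decode_from_ternary_gradients on plain lists.

  `sizes` plays the role of tf.size(variable): the element count of the
  original tensor, which an encoded payload may not reveal by itself.
  """
  decoded: List[Optional[List[float]]] = []
  for payload, scaler, size in zip(encoded, scalers, sizes):
    if payload is None:
      decoded.append(None)
      continue  # exactly one output per input
    if size < size_to_binarize:
      decoded.append([float(x) for x in payload])  # small tensor: sent as-is
    else:
      decoded.append(ternary_decode(payload, scaler))
  return decoded
```

Passing the original element counts separately mirrors the original's use of `tf.size(variable)` rather than `tf.size(gradient)`: once a gradient is encoded, its own size need not match the tensor it came from.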