bingrad_common.py source code

python
Project: terngrad, author: wenwei202

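import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.cond, tf.name_scope, tf.IndexedSlices)

FLAGS = tf.app.flags.FLAGS  # assumed: the `size_to_binarize` flag is defined elsewhere in the project


def encode_to_ternary_gradients(grads_and_vars, get_shape=False):
  """Encode each gradient tensor in `grads_and_vars` for communication.

  Gradients with fewer than FLAGS.size_to_binarize elements are sent
  uncompressed (their bytes are bitcast to uint8); larger ones are
  compressed by `ternary_encoder`, defined elsewhere (not shown).
  None gradients are passed through unchanged.

  Returns a list of (encoded_gradient, variable) pairs and, if `get_shape`
  is True, also the list of original gradient shapes needed for decoding.
  """
  with tf.name_scope('ternary_encoder'):
    gradients, variables = zip(*grads_and_vars)
    ternary_gradients = []
    gradient_shapes = []
    for gradient in gradients:
      if gradient is None:
        # Variables without a gradient keep a None placeholder.
        ternary_gradients.append(None)
        if get_shape:
          gradient_shapes.append(None)
        continue

      if get_shape:
        # Record the original shape so the receiver can decode the gradient.
        if isinstance(gradient, tf.IndexedSlices):
          gradient_shape = gradient.dense_shape  # dynamic shape tensor
        else:
          gradient_shape = gradient.get_shape()  # static TensorShape
        gradient_shapes.append(gradient_shape)

      # tf.cond builds both branches immediately, so capturing the loop
      # variable `gradient` in the lambdas is safe. Both branches must
      # return the same dtype, so ternary_encoder must also yield tf.uint8.
      ternary_gradient = tf.cond(tf.size(gradient) < FLAGS.size_to_binarize,
                                 lambda: tf.bitcast(gradient, type=tf.uint8),
                                 lambda: ternary_encoder(gradient))
      ternary_gradients.append(ternary_gradient)

    if get_shape:
      return list(zip(ternary_gradients, variables)), gradient_shapes
    else:
      return list(zip(ternary_gradients, variables))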
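To make the control flow concrete without a TensorFlow dependency, here is a minimal pure-Python sketch that mirrors encode_to_ternary_gradients on flat lists of floats. It is not the repository's code. `SIZE_TO_BINARIZE`, `raw_bytes`, `stochastic_ternarize`, `pack_ternary`, `encode_gradients` and the `rng` parameter are hypothetical names. The ternarization step assumes the stochastic scheme described in the TernGrad paper: with s = max|v|, each element becomes s · sign(v) with probability |v|/s and 0 otherwise, which keeps the estimate unbiased. The 2-bit code layout and the leading 4-byte scale are also assumptions, since ternary_encoder's actual byte format is not shown here.

```python
import random
import struct

# Hypothetical stand-in for FLAGS.size_to_binarize (the real value is a flag).
SIZE_TO_BINARIZE = 8


def raw_bytes(values):
    """Mimic tf.bitcast(float32 tensor, tf.uint8): 4 uncompressed bytes per element.

    Little-endian byte order is assumed.
    """
    return list(struct.pack("<%df" % len(values), *values))


def stochastic_ternarize(values, rng):
    """Return (s, signs): each sign is sign(v) with probability |v| / s, else 0.

    s = max |v|, so s * signs is an unbiased estimate of values.
    """
    s = max(abs(v) for v in values)
    if s == 0.0:
        return s, [0] * len(values)
    signs = []
    for v in values:
        keep = rng.random() < abs(v) / s  # always true when |v| == s
        signs.append((1 if v > 0 else -1) if keep else 0)
    return s, signs


def pack_ternary(signs):
    """Pack values in {-1, 0, +1} into 2-bit codes, four per byte.

    Hypothetical layout: 0 -> 0b00, +1 -> 0b01, -1 -> 0b10, lowest bits first.
    """
    codes = {0: 0, 1: 1, -1: 2}
    packed = []
    for i in range(0, len(signs), 4):
        byte = 0
        for j, t in enumerate(signs[i:i + 4]):
            byte |= codes[t] << (2 * j)
        packed.append(byte)
    return packed


def encode_gradients(grads_and_vars, get_shape=False, rng=None):
    """Pure-Python mirror of encode_to_ternary_gradients for flat float lists."""
    if rng is None:
        rng = random.Random(0)  # seeded so the example is reproducible
    encoded, shapes = [], []
    for grad, var in grads_and_vars:
        if grad is None:  # variables without a gradient pass through as None
            encoded.append((None, var))
            shapes.append(None)
            continue
        shapes.append((len(grad),))  # shape kept so a receiver could decode
        if len(grad) < SIZE_TO_BINARIZE:
            # Small gradient: send the raw float bytes (lossless, no compression).
            encoded.append((raw_bytes(grad), var))
        else:
            # Large gradient: 4-byte scale followed by packed 2-bit codes.
            scale, signs = stochastic_ternarize(grad, rng)
            encoded.append((raw_bytes([scale]) + pack_ternary(signs), var))
    if get_shape:
        return encoded, shapes
    return encoded
```

Both branches produce a flat list of byte values, echoing the requirement that the two tf.cond branches return the same dtype (tf.uint8). For example, `encode_gradients([([0.5, -0.25], "w0"), (None, "w1")])` sends the two floats of w0 as 8 raw bytes and passes w1 through as None.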