test_ternary_encoder_decoder.py 文件源码

python
阅读 55 收藏 0 点赞 0 评论 0

项目:terngrad 作者: wenwei202 项目源码 文件源码
def ternary_encoder(input_data):
  """Encoding and compressing the signs """
  a = tf.sign(input_data) # -1, 0, 1
  a = tf.add(a,1) # shift -1,0,1 to 0,1,2 (2'b00,2'b01,2'b10)
  a = tf.reshape(a,[-1])
  pad_size = 4 - tf.mod(tf.size(a), 4)
  pad = tf.range(0.0, pad_size)
  a = tf.concat([a, pad], 0)
  a_split1, a_split2, a_split3, a_split4 = tf.split(a,4) # assume the size is dividable by 4

  # encode 4 grads into 1 Byte
  sum_1 = tf.add(a_split1, a_split2*4)
  sum_2 = tf.add(a_split3*16, a_split4*64)
  sum_all = tf.add(sum_1, sum_2)
  encoded = tf.cast(sum_all, tf.uint8)
  return encoded
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号