def ternary_encoder(input_data):
    """Encode the element-wise signs of a tensor, packing four signs per byte.

    Each element is reduced to its sign (-1, 0 or 1) and shifted to the 2-bit
    code 0, 1 or 2 (0b00, 0b01, 0b10). The codes are flattened and padded so
    the element count is a multiple of 4. The flat vector is then split into
    four contiguous quarters; element ``i`` of quarter ``k`` (k = 0..3) is
    stored in bits ``2k..2k+1`` of output byte ``i``.

    Args:
        input_data: Tensor of any shape whose signs are encoded. Its dtype
            must be one that ``tf.sign`` accepts (presumably a float gradient
            tensor in the caller -- confirm).

    Returns:
        A 1-D ``tf.uint8`` tensor of length ``ceil(size(input_data) / 4)``.
        The original shape and element count are not stored. A decoder must
        be given them and must discard the trailing padded entries, which
        hold the code for sign 0.
    """
    codes = tf.sign(input_data)  # -1, 0, 1
    codes = tf.add(codes, 1)  # shift -1, 0, 1 to 0, 1, 2 (2'b00, 2'b01, 2'b10)
    codes = tf.reshape(codes, [-1])

    # Pad only up to the next multiple of 4 (0-3 elements). `4 - size % 4`
    # alone would append a needless 4-element pad (one extra output byte)
    # whenever the size is already a multiple of 4.
    pad_size = tf.mod(4 - tf.mod(tf.size(codes), 4), 4)
    # Pad in the input's own dtype so tf.concat accepts it for any input dtype.
    # Use code 1 (sign 0) so an unstripped pad entry still decodes to a
    # harmless zero rather than an invalid code such as 3.
    pad = tf.ones([pad_size], dtype=codes.dtype)
    codes = tf.concat([codes, pad], 0)

    # The padding above guarantees the length is divisible by 4.
    q0, q1, q2, q3 = tf.split(codes, 4)

    # Encode 4 sign codes into 1 byte: each quarter gets its own 2-bit field.
    # The maximum is 2 + 2*4 + 2*16 + 2*64 = 170, which fits in uint8.
    low = tf.add(q0, q1 * 4)
    high = tf.add(q2 * 16, q3 * 64)
    packed = tf.add(low, high)
    encoded = tf.cast(packed, tf.uint8)
    return encoded
test_ternary_encoder_decoder.py 文件源码
python
阅读 55
收藏 0
点赞 0
评论 0
评论列表
文章目录