def ternary_decoder(encoded_data, scaler, shape):
    """Decode packed 2-bit codes into a scaled float tensor.

    Every element of ``encoded_data`` is read as four base-4 digits
    (value // 1, // 4, // 16 and // 64, each taken modulo 4).  The four
    digit tensors are concatenated along axis 0 -- all lowest digits first,
    then all second digits, and so on -- trimmed to ``prod(shape)``
    elements, reshaped to ``shape``, shifted by -1 and multiplied by
    ``scaler``.  A digit ``d`` therefore decodes to ``(d - 1) * scaler``,
    i.e. digits 0, 1, 2 become ``-scaler``, ``0`` and ``+scaler``.

    Args:
        encoded_data: Tensor of packed codes; it is cast to int32 here.
            Presumably the 1-D byte tensor produced by the matching
            encoder -- TODO confirm against the caller.
        scaler: Factor the shifted digits are multiplied by.
        shape: Shape of the decoded tensor.  Its element count must not
            exceed four times the number of packed elements, since only
            four digits are unpacked per element.

    Returns:
        Tensor of shape ``shape`` holding ``(digit - 1) * scaler``.
    """
    packed = tf.cast(encoded_data, tf.int32)

    # Use explicit floor division: plain ``/`` is true division on Python 3
    # and would route the int32 codes through float64 before the modulo.
    # With ``//`` every intermediate stays int32 on both Python 2 and 3.
    digits = [tf.mod(packed, 4)]
    digits += [tf.mod(packed // divisor, 4) for divisor in (4, 16, 64)]
    unpacked = tf.concat(digits, 0)

    # The packed buffer may carry padding digits at the end; keep only the
    # first prod(shape) decoded values before restoring the shape.
    real_size = tf.reduce_prod(shape)
    values = tf.cast(unpacked, tf.float32)
    values = tf.gather(values, tf.range(0, real_size))
    values = tf.reshape(values, shape)

    # Shift digits {0, 1, 2} to signs {-1, 0, +1}, then apply the scale.
    signs = tf.subtract(values, 1)
    decoded = signs * scaler
    return decoded
# Comment list
# Table of contents