def p_ternarize(x, p):
    """Ternarize ``x`` to {-1, 0, +1}, zeroing roughly a fraction ``p`` of entries.

    ``x`` is first squashed with ``tanh``. Entries whose magnitude exceeds the
    non-trainable threshold variable ``T`` keep their sign; all others become 0.

    The threshold is NOT refreshed by the returned tensor. Instead, an assign
    op that sets ``T`` to the k-th largest magnitude of ``tanh(x)`` is added
    to the ``'update_thre_op'`` collection. Here ``k = max(int(N * (1 - p)), 1)``
    and ``N`` is the number of elements. The caller must run that op. Before
    it has ever run, ``T`` holds its initial value 0.05.

    Backward pass: ``Sign`` is given an identity gradient and ``Mul`` is given
    ``Add``'s gradient (a straight-through estimator). The gradient therefore
    reaches ``tanh(x)`` without being multiplied by the mask.

    The variable name ``'T'`` is fixed, so each call needs its own variable
    scope (or scope reuse).

    Args:
        x: Floating-point tensor whose number of elements is statically known
            (needed to compute ``k`` at graph-construction time).
        p: Target fraction of entries to zero, in ``[0, 1]``.

    Returns:
        Tensor with the same shape as ``x`` and values in {-1, 0, +1}.

    Raises:
        ValueError: if ``p`` is outside ``[0, 1]``, or if the element count of
            ``x`` is unknown or zero.
    """
    # Validate before touching the graph so a bad call leaves no stray variable.
    if not 0 <= p <= 1:
        raise ValueError('p must be in [0, 1], got %r' % (p,))

    x = tf.tanh(x)

    # Scalar threshold. It is excluded from training and also registered in
    # the 'thresholds' collection.
    thre = tf.get_variable('T', trainable=False,
                           collections=[tf.GraphKeys.VARIABLES, 'thresholds'],
                           initializer=0.05)

    flat_x = tf.reshape(x, [-1])
    num_elements = flat_x.get_shape().dims[0].value
    if not num_elements:
        raise ValueError('p_ternarize needs x to have a statically known, '
                         'non-zero number of elements')
    # Clamp to >= 1: top_k(..., 0) would leave nothing for topK[-1] to read.
    k = max(int(num_elements * (1 - p)), 1)

    # top_k returns values sorted in descending order, so the last entry is the
    # k-th largest magnitude.
    topK, _ = tf.nn.top_k(tf.abs(flat_x), k)
    update_thre = thre.assign(topK[-1])
    tf.add_to_collection('update_thre_op', update_thre)

    # 1 where |x| > T (same as x > T or x < -T), else 0.
    # NOTE(review): the comparison is strict, so the k-th largest entry itself
    # is zeroed too (about k - 1 survive). Confirm whether that is intended.
    mask = tf.cast(tf.abs(x) > thre, x.dtype)

    # Straight-through gradient. G is defined outside this function
    # (presumably the default graph; confirm).
    with G.gradient_override_map({"Sign": "Identity", "Mul": "Add"}):
        w = tf.sign(x) * tf.stop_gradient(mask)
    tf.histogram_summary(w.name, w)
    return w
评论列表
文章目录