def tw_ternarize(x, thre):
    """Ternarize a weight tensor with trainable positive/negative scales.

    Each element of ``x`` is mapped to one of three levels::

        +Wp   where  x >  thre_x
         0    where |x| <= thre_x
        -Wn   where  x < -thre_x

    with ``thre_x = thre * max(|x|)``, computed under ``tf.stop_gradient``.

    ``Wp`` and ``Wn`` are scalar variables obtained with ``tf.get_variable``
    (initializer 1.0) in the current variable scope. They are registered in
    ``tf.GraphKeys.VARIABLES`` plus the ``'positives'`` / ``'negatives'``
    collections respectively. A scalar summary is emitted for each, and a
    histogram summary for the result.

    Gradients:
      * ``Wp`` / ``Wn`` get their gradient through the final (ordinary)
        multiply by the per-element scale mask.
      * ``x`` gets a straight-through gradient: inside the override context
        ``Sign`` uses the ``Identity`` gradient and the ``sign * mask_z``
        multiply uses the ``Add`` gradient. So ``dL/dx = dL/dw * scale``,
        where ``scale`` is ``Wp`` / ``1`` / ``Wn`` in the positive / zero /
        negative regions.

    Args:
        x: Tensor of full-precision weights. Presumably float32, since
            ``Wp``/``Wn`` are created as float32 (TODO confirm at call sites).
        thre: Threshold ratio relative to ``max(|x|)``.

    Returns:
        Tensor shaped like ``x`` containing the ternarized weights.
    """
    # Threshold is a constant w.r.t. backprop.
    thre_x = tf.stop_gradient(tf.reduce_max(tf.abs(x)) * thre)

    w_p = tf.get_variable('Wp',
                          collections=[tf.GraphKeys.VARIABLES, 'positives'],
                          initializer=1.0)
    w_n = tf.get_variable('Wn',
                          collections=[tf.GraphKeys.VARIABLES, 'negatives'],
                          initializer=1.0)
    tf.scalar_summary(w_p.name, w_p)
    tf.scalar_summary(w_n.name, w_n)

    # ones_like/zeros_like follow x's runtime shape, so this also works when
    # the static shape of x is only partially defined (tf.ones(x.get_shape())
    # needs a fully-defined shape).
    ones = tf.ones_like(x)
    zeros = tf.zeros_like(x)

    # Per-element scale: Wp above +thre_x, Wn below -thre_x, 1 elsewhere.
    # It is used both as the forward scale and as the scale of x's
    # straight-through gradient.
    mask_p = tf.select(x > thre_x, ones * w_p, ones)
    mask_np = tf.select(x < -thre_x, ones * w_n, mask_p)

    # The zero band is inclusive. With strict comparisons an element with
    # |x| == thre_x matched no region and leaked through as sign(x) = +/-1,
    # which is not a ternary level.
    mask_z = tf.select(tf.abs(x) <= thre_x, zeros, ones)

    # G is a graph defined outside this function (presumably the default
    # graph; TODO confirm). The override must apply to the graph these ops
    # are created in. Only Sign and this one Mul are built inside the
    # context, so no other op's gradient is affected.
    with G.gradient_override_map({"Sign": "Identity", "Mul": "Add"}):
        w = tf.sign(x) * tf.stop_gradient(mask_z)
    # Ordinary multiply: this is where Wp / Wn receive their gradients.
    w = w * mask_np

    tf.histogram_summary(w.name, w)
    return w
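# Stray web-page text from the original paste (not part of the code):
# "评论列表" (Comment list), "文章目录" (Table of contents).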