ternary.py source file

python

Project: ternarynet  Author: czhu95
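import tensorflow as tf

# Assumption: G is not defined in this excerpt; it is taken here to be the default graph.
G = tf.get_default_graph()


def tw_ternarize(x, thre):
    # Uses pre-1.0 TensorFlow API names (tf.select, tf.scalar_summary, tf.histogram_summary).
    shape = x.get_shape()

    # Threshold is a fraction of the largest absolute weight; no gradient flows through it.
    thre_x = tf.stop_gradient(tf.reduce_max(tf.abs(x)) * thre)

    # Trainable scales for the positive and negative weights.
    w_p = tf.get_variable('Wp', collections=[tf.GraphKeys.VARIABLES, 'positives'], initializer=1.0)
    w_n = tf.get_variable('Wn', collections=[tf.GraphKeys.VARIABLES, 'negatives'], initializer=1.0)

    tf.scalar_summary(w_p.name, w_p)
    tf.scalar_summary(w_n.name, w_n)

    # mask_np: w_p where x > thre_x, w_n where x < -thre_x, 1 elsewhere.
    # mask_z:  0 where |x| < thre_x, 1 elsewhere.
    mask = tf.ones(shape)
    mask_p = tf.select(x > thre_x, tf.ones(shape) * w_p, mask)
    mask_np = tf.select(x < -thre_x, tf.ones(shape) * w_n, mask_p)
    mask_z = tf.select((x < thre_x) & (x > - thre_x), tf.zeros(shape), mask)

    # Inside this block the backward pass treats Sign as Identity and uses
    # Add's gradient for Mul, so the gradient passes straight through to x.
    with G.gradient_override_map({"Sign": "Identity", "Mul": "Add"}):
        w = tf.sign(x) * tf.stop_gradient(mask_z)

    # Forward result: +w_p above the threshold, -w_n below it, 0 in between.
    w = w * mask_np

    tf.histogram_summary(w.name, w)
    return w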
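
To make the behaviour of tw_ternarize easier to check by hand, here is a minimal plain-Python sketch of what it appears to compute, without TensorFlow. The names ternarize_forward and ternarize_grad_x are hypothetical and not part of ternarynet. w_p and w_n are passed in as plain floats instead of trainable variables, and x is a flat, non-empty list. The gradient function reflects one reading of the override map (the gradient reaching x is the upstream gradient times mask_np) and should be treated as an assumption rather than a verified property of the original graph.

```python
def ternarize_forward(x, thre, w_p=1.0, w_n=1.0):
    """Forward value of tw_ternarize for a flat, non-empty list of floats."""
    thre_x = max(abs(v) for v in x) * thre
    out = []
    for v in x:
        if v > thre_x:
            out.append(w_p)           # mask_np = w_p, sign = +1
        elif v < -thre_x:
            out.append(-w_n)          # mask_np = w_n, sign = -1
        elif -thre_x < v < thre_x:
            out.append(0.0)           # mask_z = 0
        else:
            # v == +/-thre_x: the strict comparisons leave every mask at 1,
            # so sign(v) passes through unscaled.
            out.append(float((v > 0) - (v < 0)))
    return out


def ternarize_grad_x(x, thre, grad_w, w_p=1.0, w_n=1.0):
    """Assumed gradient w.r.t. x: upstream gradient scaled by mask_np."""
    thre_x = max(abs(v) for v in x) * thre
    grads = []
    for v, g in zip(x, grad_w):
        if v < -thre_x:
            scale = w_n
        elif v > thre_x:
            scale = w_p
        else:
            scale = 1.0               # zeroed weights still receive the gradient
        grads.append(g * scale)
    return grads
```

For example, ternarize_forward([0.9, -0.8, 0.1], 0.5, w_p=2.0, w_n=3.0) uses a threshold of 0.45 and maps the three weights to 2.0, -3.0 and 0.0 respectively.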