ternary.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:ternarynet 作者: czhu95 项目源码 文件源码
def p_ternarize(x, p):

    x = tf.tanh(x)
    shape = x.get_shape()

    thre = tf.get_variable('T', trainable=False, collections=[tf.GraphKeys.VARIABLES, 'thresholds'],
            initializer=0.05)
    flat_x = tf.reshape(x, [-1])
    k = int(flat_x.get_shape().dims[0].value * (1 - p))
    topK, _ = tf.nn.top_k(tf.abs(flat_x), k)
    update_thre = thre.assign(topK[-1])
    tf.add_to_collection('update_thre_op', update_thre)

    mask = tf.zeros(shape)
    mask = tf.select((x > thre) | (x < -thre), tf.ones(shape), mask)

    with G.gradient_override_map({"Sign": "Identity", "Mul": "Add"}):
        w =  tf.sign(x) * tf.stop_gradient(mask)

    tf.histogram_summary(w.name, w)
    return w
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号