terpret_tf_log_runtime.py source code

python

Project: TerpreT  Author: 51alg
def weighted_sum(components, weights, scope=""):
    # n: num_components
    # b: batch_size
    # c: component_size
    with tf.name_scope(scope):
        weight_is_batched = (weights.get_shape().ndims == 2)
        if weight_is_batched:
            set_batch_size = tf.shape(weights)[0]
        else:
            set_batch_size = None
        components, is_batched = make_batch_consistent(components, set_batch_size=set_batch_size)
        # Note: tf.pack was renamed to tf.stack in TensorFlow 1.0.
        components = tf.pack(components) # [n x b x c]

        weight_rank = weights.get_shape().ndims
        assert_rank_1_or_2(weight_rank)
        if weight_rank == 1:
            weights = tf.reshape(weights, [-1,1,1]) # [n x 1 x 1]
        elif weight_rank == 2:
            weights = tf.expand_dims(tf.transpose(weights, [1, 0]),2) # [n x b x 1]

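        # Values appear to be log-domain (this is the log runtime), so adding
        # the weights here corresponds to multiplying in probability space.
        components += weights
        # TODO: change this to tf.reduce_logsumexp when it is released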
        w_sum = logsumexp(components, reduction_indices=0) # [b x c]
        if not is_batched: w_sum = tf.squeeze(w_sum) # [c]
    return w_sum
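
The function above relies on helpers (`make_batch_consistent`, `assert_rank_1_or_2`, `logsumexp`) that are defined elsewhere in the file and are not shown here. As a rough illustration of the underlying idea only, the following is a minimal NumPy sketch of a log-domain weighted sum for the unbatched case. It assumes both the components and the weights are log-probabilities; the names `weighted_sum_log` and the local `logsumexp` are hypothetical stand-ins, not TerpreT's own code.

```python
import numpy as np

def logsumexp(x, axis=0):
    # Numerically stable log(sum(exp(x))) along `axis`.
    # Assumes finite inputs (no -inf entries).
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))

def weighted_sum_log(log_components, log_weights):
    # log_components: list of n arrays, each of shape [c]
    # log_weights:    array of shape [n]
    stacked = np.stack(log_components)                      # [n x c]
    shifted = stacked + np.reshape(log_weights, [-1, 1])    # [n x c]
    # log(sum_i w_i * p_i) computed without leaving log space
    return logsumexp(shifted, axis=0)                       # [c]

# Two distributions over two values, mixed with weights 0.25 and 0.75.
p = np.array([0.2, 0.8])
q = np.array([0.5, 0.5])
w = np.array([0.25, 0.75])

result = np.exp(weighted_sum_log([np.log(p), np.log(q)], np.log(w)))
expected = w[0] * p + w[1] * q
print(result, expected)
```

Exponentiating the result should agree with the ordinary weighted sum `w[0] * p + w[1] * q`, which is the sanity check the last lines perform.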