terpret_tf_runtime.py source code

python

Project: TerpreT    Author: 51alg
import numpy as np
import tensorflow as tf

# NOTE: this code targets the TensorFlow 1.x API (Dimension.value,
# tf.unsorted_segment_sum). one_hot, slice_out_int_literals and
# make_batch_consistent are helpers from terpret_tf_runtime.py that
# are not shown in this snippet.
def apply_factor(tensor, *args, **kwargs):
    scope = kwargs.pop("scope", "")
    with tf.name_scope(scope):
        n_args = len(args)

        if n_args == 0:
            # No arguments: the factor's value is a constant, so its
            # distribution is just a one-hot encoding of that value.
            tensor, output_size, error_symbol = tensor
            return one_hot(tensor, output_size, scope=scope)
        else:
            tensor, args = slice_out_int_literals(tensor, list(args))
            args, is_batched = make_batch_consistent(args)
            tensor, output_size, error_symbol = tensor

            # handle the case where all arguments were int literals
            tensor_dim_sizes = [dim.value for dim in tensor.get_shape()]
            if not tensor_dim_sizes:
                return one_hot(tensor, output_size, scope=scope)

            # Each arg is batch size x arg dim. Add dimensions to enable broadcasting.
            for i, arg in enumerate(args):
                for j in range(n_args):
                    if j == i: continue
                    args[i] = tf.expand_dims(args[i], j + 1)

            # compute the joint distribution over argument values
            # (an outer product, realized by broadcasting)
            joint = 1
            for arg in args:
                joint = joint * arg

            # flatten so each row of joint lines up with one entry of the
            # flattened factor table, as unsorted_segment_sum expects
            joint = tf.reshape(joint, (-1, np.prod(tensor_dim_sizes)))
            joint = tf.transpose(joint, [1, 0])  # |tensor| x batch_size

            if error_symbol is not None:
                result = tf.unsorted_segment_sum(joint, tf.reshape(tensor, [-1]), output_size + 1)
                # assume error bin is last bin
                result = result[:output_size, :]
            else:
                result = tf.unsorted_segment_sum(joint, tf.reshape(tensor, [-1]), output_size)

            result = tf.transpose(result, [1, 0])
            if not is_batched:
                result = tf.squeeze(result)
            return result
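
For intuition, here is a minimal, self-contained sketch of the pattern apply_factor implements: pushing independent discrete marginals through an integer lookup table via an outer product followed by a scatter-add. It drops batching and the helper functions, uses an invented (a + b) % 3 factor rather than anything from TerpreT, and is written against the TensorFlow 2.x eager API instead of the 1.x API used above.

import numpy as np
import tensorflow as tf

# An illustrative factor: out = (a + b) % 3 for a, b in {0, 1, 2}.
# In TerpreT terms this plays the role of the integer lookup table `tensor`.
factor_table = np.array([[(a + b) % 3 for b in range(3)] for a in range(3)])
output_size = 3

# Marginal distributions over the two arguments (no batch dimension here).
p_a = tf.constant([0.5, 0.3, 0.2])
p_b = tf.constant([0.1, 0.6, 0.3])

# Outer product via broadcasting: joint[a, b] = p_a[a] * p_b[b].
joint = p_a[:, None] * p_b[None, :]

# Scatter-add the joint mass into output bins indexed by the factor table
# (the unsorted_segment_sum step above).
p_out = tf.math.unsorted_segment_sum(
    tf.reshape(joint, [-1]),
    tf.reshape(tf.constant(factor_table), [-1]),
    output_size)

print(p_out.numpy())  # marginal of (a + b) % 3; the entries sum to 1

The reshape/transpose dance in apply_factor is the batched generalization of this: joint is laid out as |tensor| x batch_size so a single unsorted_segment_sum handles every batch element at once.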