terpret_tf_log_runtime.py source code

python

Project: TerpreT  Author: 51alg
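import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (e.g. Dimension.value)

# one_hot, slice_out_int_literals, make_batch_consistent and logsumexp are
# assumed to be helpers defined elsewhere in terpret_tf_log_runtime.py.


def apply_factor(tensor, *args, **kwargs):
    """Push argument log-probabilities through a lookup-table factor.

    `tensor` is a (table, output_size, error_symbol) triple; each table entry
    is the output value for one combination of argument values. Each element
    of `args` holds log-probabilities over one argument's values (integer
    literal arguments are sliced out of the table first). For every output
    value, the result is the logsumexp of the joint log-probability over all
    argument combinations the table maps to it; combinations mapped to
    error_symbol are discarded.
    """
    scope = kwargs.pop("scope", "")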
    with tf.name_scope(scope):
        n_args = len(args)

        if n_args == 0:  # no arguments: the factor's value is fixed
            tensor, output_size, error_symbol = tensor
            return one_hot(tensor, output_size, scope=scope)
        else:
            tensor, args = slice_out_int_literals(tensor, list(args))
            args, is_batched = make_batch_consistent(args)
            tensor, output_size, error_symbol = tensor

            # handle the case where all arguments were int literals
            tensor_dim_sizes = [dim.value for dim in tensor.get_shape()]
            if not tensor_dim_sizes:
                return one_hot(tensor, output_size, scope=scope)

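            # Each arg has shape (batch_size, arg_dim). Insert singleton axes so
            # that arg i keeps its values on axis i + 1 and all args broadcast
            # to batch_size x d_0 x ... x d_{n-1}.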
            for i, arg in enumerate(args):
                for j in range(len(args)):
                    if j == i: continue
                    args[i] = tf.expand_dims(args[i], j + 1)

            # Sum the log-probabilities to get the joint log-probability of
            # every argument combination, before the table is applied.
            joint = 0
            for arg in args:
                joint = joint + arg

            # Flatten to one row per table entry, in the same order as flat_tensor.
            joint = tf.reshape(joint, (-1, np.prod(tensor_dim_sizes)))
            joint = tf.transpose(joint, [1, 0])  # |tensor| x batch_size

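            # Group the joint entries by the output value the table assigns
            # them; the group for error_symbol (if any) is dropped.
            flat_tensor = tf.reshape(tensor, [-1])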
            if error_symbol is not None:
                to_logsumexp = tf.dynamic_partition(joint, flat_tensor, output_size + 1)
                del to_logsumexp[error_symbol]
            else:
                to_logsumexp = tf.dynamic_partition(joint, flat_tensor, output_size)

            # Marginalize each group with logsumexp: output_size x batch_size.
            result = tf.stack(
                [logsumexp(x, reduction_indices=0) for x in to_logsumexp]
            )

            result = tf.transpose(result, [1, 0])  # batch_size x output_size
            if not is_batched:
                result = tf.squeeze(result)
            return result
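
To make the tensor bookkeeping in apply_factor easier to follow, here is a minimal NumPy sketch of the same computation. It assumes the table is an integer NumPy array and that each argument is a (batch_size, d_i) array of log-probabilities. The names apply_factor_reference and logsumexp0 are hypothetical, introduced only for illustration; they are not part of TerpreT, and the sketch skips the integer-literal slicing and the zero-argument case that the TensorFlow version handles.

```python
import numpy as np


def logsumexp0(x):
    """Stable log(sum(exp(x))) over axis 0; -inf where nothing is summed."""
    if x.shape[0] == 0:
        return np.full(x.shape[1:], -np.inf)
    m = np.max(x, axis=0)
    m = np.where(np.isfinite(m), m, 0.0)  # avoid (-inf) - (-inf) = nan
    with np.errstate(divide="ignore"):  # log(0) -> -inf is intended
        return m + np.log(np.sum(np.exp(x - m), axis=0))


def apply_factor_reference(table, output_size, error_symbol, log_args):
    """Hypothetical NumPy version of the table-factor computation.

    table:    int array, shape (d_0, ..., d_{n-1}); entry = output value
    log_args: n >= 1 arrays, the i-th of shape (batch_size, d_i)
    returns:  array of shape (batch_size, output_size)
    """
    n = len(log_args)
    batch_size = log_args[0].shape[0]

    # Arg i keeps its values on axis i + 1; size-1 axes elsewhere broadcast.
    joint = 0.0
    for i, a in enumerate(log_args):
        shape = [batch_size] + [1] * n
        shape[i + 1] = a.shape[1]
        joint = joint + a.reshape(shape)

    # One row per table entry, in the same (C) order as table.reshape(-1).
    joint = joint.reshape(batch_size, -1).T
    flat = table.reshape(-1)

    # Group rows by output value, skipping error_symbol, and marginalize.
    n_values = output_size + 1 if error_symbol is not None else output_size
    groups = [v for v in range(n_values) if v != error_symbol]
    return np.stack([logsumexp0(joint[flat == v]) for v in groups], axis=1)


if __name__ == "__main__":
    xor = np.array([[0, 1], [1, 0]])  # output = XOR of two binary args
    p = np.log(np.array([[0.3, 0.7]]))
    q = np.log(np.array([[0.4, 0.6]]))
    print(np.exp(apply_factor_reference(xor, 2, None, [p, q])))  # ≈ [[0.54 0.46]]
```

Selecting rows with a boolean mask per output value plays the role of tf.dynamic_partition here: each group collects the table entries mapped to one output value, and an output value that no entry maps to ends up with log-probability -inf.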