text_corrector_models.py source code


Project: deep-text-corrector  Author: atpaino
# Imports used by the code below (from the top of the surrounding file):
from tensorflow.python.ops import array_ops, embedding_ops, math_ops


def apply_input_bias_and_extract_argmax_fn_factory(input_bias):
    """Build a loop-function factory whose outputs are masked by input_bias.

    :param input_bias: a tensor broadcastable against the decoder's
        [batch_size, vocab_size] output distribution; it is multiplied
        into the projected outputs so that only permitted tokens can be
        selected as the next decoded symbol.
    :return: a factory with the same signature as the standard
        extract-argmax-and-embed helper, which in turn produces a
        decoder loop function.
    """

    def fn_factory(embedding, output_projection=None, update_embedding=True):
        """Get a loop_function that extracts the previous symbol and embeds it.

        Args:
          embedding: embedding tensor for symbols.
          output_projection: None or a pair (W, B). If provided, each fed previous
            output will first be multiplied by W and added B.
          update_embedding: Boolean; if False, the gradients will not propagate
            through the embeddings.

        Returns:
          A loop function.
        """
        def loop_function(prev, _):
            # Project the raw decoder output into vocabulary space and
            # apply the input bias before choosing the next token.
            prev = project_and_apply_input_bias(prev, output_projection,
                                                input_bias)

            prev_symbol = math_ops.argmax(prev, 1)
            # Note that gradients will not propagate through the second
            # parameter of embedding_lookup.
            emb_prev = embedding_ops.embedding_lookup(embedding, prev_symbol)
            if not update_embedding:
                emb_prev = array_ops.stop_gradient(emb_prev)
            # Unlike the stock loop function, also return the chosen
            # symbol so the caller can track the decoded sequence.
            return emb_prev, prev_symbol
        return loop_function

    return fn_factory
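For context, project_and_apply_input_bias is defined elsewhere in text_corrector_models.py and is not shown on this page. The sketch below is a minimal reconstruction of the idea, assuming TensorFlow 1.x and that input_bias is a [batch_size, vocab_size] mask; the body is illustrative, not the exact source:

import tensorflow as tf
from tensorflow.python.ops import nn_ops


def project_and_apply_input_bias(logits, output_projection, input_bias):
    # Optionally project the decoder's raw output into vocabulary space.
    if output_projection is not None:
        logits = nn_ops.xw_plus_b(
            logits, output_projection[0], output_projection[1])
    # Softmax first so every entry is positive, then zero out the
    # tokens that the bias mask disallows.
    probs = tf.nn.softmax(logits)
    return probs * input_bias

A hypothetical usage sketch of the factory itself (embedding, W, b, and decoder_output are placeholder names, not identifiers from this file):

loop_fn_builder = apply_input_bias_and_extract_argmax_fn_factory(input_bias)
loop_fn = loop_fn_builder(embedding, output_projection=(W, b),
                          update_embedding=False)
emb_prev, prev_symbol = loop_fn(decoder_output, 1)  # second arg: time step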