p1_seq2seq.py source code

python

Project: text_classification  Author: brightmart
import tensorflow as tf


def extract_argmax_and_embed(embedding, output_projection=None):
    """
    Get a loop_function that extracts the previous symbol and embeds it. Used by the decoder.
    :param embedding: embedding tensor for the symbols, shape [vocab_size, embed_size]
    :param output_projection: None or a pair (W, B). If provided, each fed previous output is
    first multiplied by W and then has B added to it.
    :return: a loop function
    """
    def loop_function(prev, _):
        # Optionally project the previous decoder output to vocabulary logits.
        if output_projection is not None:
            prev = tf.matmul(prev, output_projection[0]) + output_projection[1]
        prev_symbol = tf.argmax(prev, 1)  # index of the most likely symbol
        emb_prev = tf.gather(embedding, prev_symbol)  # look up the embedding for that index
        return emb_prev
    return loop_function
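
To make the data flow concrete, here is a minimal NumPy sketch of what a single call to loop_function computes: project, take the argmax, look up the embedding. It is a stand-in for the TensorFlow ops, not the project's code; the helper name argmax_and_embed and all toy values are assumptions made for illustration.

```python
import numpy as np


def argmax_and_embed(prev, embedding, output_projection=None):
    """NumPy stand-in (hypothetical helper): project, argmax, embedding lookup."""
    if output_projection is not None:
        W, B = output_projection
        prev = prev @ W + B                  # [batch, hidden] -> [batch, vocab]
    prev_symbol = np.argmax(prev, axis=1)    # most likely symbol id per batch row
    return embedding[prev_symbol]            # [batch, embed_size]


# Toy sizes (assumed): vocabulary of 3 symbols, embedding size 2, hidden size 2.
embedding = np.array([[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]])
W = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])  # hidden -> vocab
B = np.zeros(3)
prev = np.array([[5.0, 1.0], [0.5, 4.0]])          # two decoder outputs

emb_prev = argmax_and_embed(prev, embedding, (W, B))
print(emb_prev)
```

The first row projects to logits [5, 0, 1] and selects symbol 0; the second projects to [0.5, 0, 4] and selects symbol 2, so the result is rows 0 and 2 of the embedding matrix.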

# Decoder part of the RNN.
# During training, use the inputs from the training data; at test time, feed the output at step t as the input at step t+1.
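
A decoder of this kind usually reads its inputs from the training data during training, and at test time feeds the output of step t back in as the input of step t+1. The sketch below illustrates that switch with a toy NumPy loop. The names decode and step_fn and the toy "next symbol" model are assumptions for illustration, not the project's decoder.

```python
import numpy as np


def decode(decoder_inputs, step_fn, loop_function=None):
    """Toy decoder loop (hypothetical): optionally feed back the previous output."""
    outputs, prev = [], None
    for i, inp in enumerate(decoder_inputs):
        if loop_function is not None and prev is not None:
            inp = loop_function(prev, i)   # test time: replace the given input
        prev = step_fn(inp)
        outputs.append(prev)
    return outputs


embedding = np.eye(3)                      # assumed one-hot embeddings, 3 symbols
shift = np.roll(np.eye(3), 1, axis=1)      # toy model: symbol k -> (k + 1) % 3


def step_fn(inp):
    return inp @ shift                     # logits over the 3 symbols


def loop_function(prev, _):
    return embedding[np.argmax(prev, axis=1)]


go = np.eye(3)[[0]]                        # batch of one, symbol 0
inputs = [go, go, go]

train_ids = [int(np.argmax(o, axis=1)[0]) for o in decode(inputs, step_fn)]
test_ids = [int(np.argmax(o, axis=1)[0]) for o in decode(inputs, step_fn, loop_function)]
print(train_ids, test_ids)
```

Without a loop function every step sees the supplied input (symbol 0) and predicts symbol 1. With the loop function only the first input is used, and each later step consumes the previous prediction, giving the sequence 1, 2, 0.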