gan.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:liveqa2017 作者: codekansas 项目源码 文件源码
def build_generator(l_var, a_var, embeddings, rnn_dims=64):
    """Builds a question generator model.

    The answer input is encoded with a bidirectional LSTM. A
    ``RecurrentAttention``-wrapped LSTM then runs over the latent input
    using that answer context. A second LSTM and a ``Dense`` projection
    into the embedding space follow.

    Args:
        l_var: keras tensor, the latent vector input. It is fed to an LSTM,
            so presumably it is a (batch, timesteps, features) sequence.
            TODO: confirm against the caller.
        a_var: keras tensor, the answer input. For the same reason it is
            presumably (batch, timesteps, features).
        embeddings: numpy array, expected to be 2-D (the shape is unpacked
            into two values). Only its second dimension is used here, to
            size the output projection. Any knn decoding against these
            embeddings presumably happens outside this function.
        rnn_dims: int, the number of units in each LSTM. Defaults to 64,
            the value that was previously hard-coded.

    Returns:
        An uncompiled keras ``Model`` with inputs
        ``[latent_var, answer_var]``. Its single output has a last
        dimension of ``embeddings.shape[1]``.
    """
    # Wrap the incoming tensors as Keras Input layers so they can be used
    # as the model's inputs below.
    latent_var = Input(tensor=l_var, name='latent_var_pl')
    answer_var = Input(tensor=a_var, name='gen_answer_pl')

    # Unpacking (rather than indexing) keeps the implicit check that
    # `embeddings` is 2-D. The vocabulary size itself is not needed here.
    _, num_embedding_dims = embeddings.shape

    # Computes context of the answer.
    a_lstm = Bidirectional(LSTM(rnn_dims, return_sequences=True))
    a_context = a_lstm(answer_var)

    # Uses context to formulate a question. RecurrentAttention is defined
    # elsewhere in the project. Presumably it lets the wrapped LSTM attend
    # over `a_context` at each step; confirm against its definition.
    q_matching_lstm = RecurrentAttention(
        LSTM(rnn_dims, return_sequences=True), a_context)
    q_var = q_matching_lstm(latent_var)
    q_var = LSTM(rnn_dims, return_sequences=True)(q_var)
    # Dense acts on the last axis, so each timestep is projected into the
    # embedding space independently.
    q_var = Dense(num_embedding_dims)(q_var)

    # Builds the model from the variables (not compiled).
    return Model(inputs=[latent_var, answer_var], outputs=[q_var])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号