gan.py source code

python

Project: liveqa2017 Author: codekansas

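from keras.layers import Embedding, Input, RepeatVector
from keras.models import Model

import yahoo  # Project-local data module (import path assumed).

# build_generator and build_discriminator are defined elsewhere in the
# project (not shown here).


def build_gan(num_latent_dims):
    """Builds a generative adversarial network.

    To train the GAN, alternate updates on the two returned training models
    in a loop: a step on `dis_trainer`, then a step on `gen_trainer`.

    Args:
        num_latent_dims: int, number of latent dimensions in the generator.

    Returns:
        gen_model: the generator, for sampling questions.
        gen_trainer: model (latent, answer) -> discriminator score on the
            generated question; only the generator's weights are updated.
        dis_trainer: model (question, latent, answer) -> discriminator scores
            on the real and the generated question; only the discriminator's
            weights are updated.
    """

    embeddings = yahoo.get_word_embeddings()

    question_var = Input(shape=(yahoo.QUESTION_MAXLEN,), name='question_var')
    answer_var = Input(shape=(yahoo.ANSWER_MAXLEN,), name='answer_var')
    latent_var = Input(shape=(num_latent_dims,), name='latent_var')

    # Frozen embedding layer initialized from the pretrained word vectors.
    vocab_size, num_embedding_dims = embeddings.shape
    emb = Embedding(vocab_size, num_embedding_dims, weights=[embeddings],
                    trainable=False)

    q_var = emb(question_var)
    a_var = emb(answer_var)
    # Tiles the latent vector once per question timestep.
    l_var = RepeatVector(yahoo.QUESTION_MAXLEN)(latent_var)

    # Creates the two models.
    gen_model = build_generator(l_var, a_var, embeddings)
    real_preds, dis_model = build_discriminator(q_var, a_var)

    # Generated question, conditioned on the latent vector and the answer.
    q_gen = gen_model([l_var, a_var])

    # Builds the model to train the generator. Keras reads `trainable` when a
    # model is compiled, so the discriminator is frozen before compiling.
    # The optimizer and loss are assumptions; the original did not compile.
    dis_model.trainable = False
    gen_model.trainable = True
    gen_preds = dis_model([q_gen, a_var])
    gen_trainer = Model([latent_var, answer_var], gen_preds)
    gen_trainer.compile(optimizer='adam', loss='binary_crossentropy')

    # Builds the model to train the discriminator, with the generator frozen.
    dis_model.trainable = True
    gen_model.trainable = False
    fake_preds = dis_model([q_gen, a_var])
    dis_trainer = Model([question_var, latent_var, answer_var],
                        [real_preds, fake_preds])
    dis_trainer.compile(optimizer='adam', loss='binary_crossentropy')

    return gen_model, gen_trainer, dis_trainer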
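
The docstring says the GAN is trained by running generator and discriminator updates in a loop. As a dependency-free illustration of that schedule (a sketch, not the project's code), here is a toy 1-D GAN in plain Python. Everything in it is hypothetical: a linear generator x = a*z + b, a logistic discriminator D(x) = sigmoid(w*x + c), real data drawn from N(4, 1), hand-derived gradients, and the non-saturating generator loss -log D(G(z)). Each step function returns updated parameters for only one network, which is the role the `trainable` toggles play in `build_gan`.

```python
import math
import random


def sigmoid(s):
    # Numerically stable logistic function.
    if s >= 0:
        return 1.0 / (1.0 + math.exp(-s))
    e = math.exp(s)
    return e / (1.0 + e)


def mean(xs):
    return sum(xs) / len(xs)


def discriminator_step(dis, gen, real, noise, lr):
    """Updates D's parameters (w, c); G's parameters (a, b) are only read."""
    w, c = dis
    a, b = gen
    fake = [a * z + b for z in noise]
    d_real = [sigmoid(w * x + c) for x in real]
    d_fake = [sigmoid(w * x + c) for x in fake]
    # Loss -[log D(real) + log(1 - D(fake))]; its gradient w.r.t. the logit
    # is (D - 1) on real samples and D on fake samples.
    grad_w = (mean([(d - 1.0) * x for d, x in zip(d_real, real)])
              + mean([d * x for d, x in zip(d_fake, fake)]))
    grad_c = mean([d - 1.0 for d in d_real]) + mean(d_fake)
    return (w - lr * grad_w, c - lr * grad_c)


def generator_step(dis, gen, noise, lr):
    """Updates G's parameters (a, b); D's parameters (w, c) stay frozen."""
    w, c = dis
    a, b = gen
    fake = [a * z + b for z in noise]
    d_fake = [sigmoid(w * x + c) for x in fake]
    # Loss -log D(G(z)); chain rule through logit w*x + c and x = a*z + b.
    grad_a = mean([(d - 1.0) * w * z for d, z in zip(d_fake, noise)])
    grad_b = mean([(d - 1.0) * w for d in d_fake])
    return (a - lr * grad_a, b - lr * grad_b)


def train(steps=200, batch=64, lr=0.05, seed=0):
    """Alternates one discriminator step and one generator step per iteration."""
    rng = random.Random(seed)
    dis = (0.0, 0.0)  # (w, c)
    gen = (1.0, 0.0)  # (a, b)
    for _ in range(steps):
        real = [rng.gauss(4.0, 1.0) for _ in range(batch)]
        noise = [rng.gauss(0.0, 1.0) for _ in range(batch)]
        dis = discriminator_step(dis, gen, real, noise, lr)
        noise = [rng.gauss(0.0, 1.0) for _ in range(batch)]
        gen = generator_step(dis, gen, noise, lr)
    return dis, gen
```

With a linear discriminator, the generator is mainly pushed to shift its samples toward the real data and cannot in general match their spread, so treat this as a demonstration of the alternating update schedule rather than a working generative model.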