def build_gan(num_latent_dims):
    """Builds a generative adversarial network.

    To train the GAN, run the updates on the generator and discriminator model
    in a loop.

    Args:
        num_latent_dims: int, number of latent dimensions in the generator.

    Returns:
        A ``(gen_model, dis_model)`` tuple: the model returned by
        ``build_generator`` and the model returned by ``build_discriminator``.
        On return ``gen_model.trainable`` is False and ``dis_model.trainable``
        is True.
    """
    embeddings = yahoo.get_word_embeddings()

    question_var = Input(shape=(yahoo.QUESTION_MAXLEN,), name='question_var')
    answer_var = Input(shape=(yahoo.ANSWER_MAXLEN,), name='answer_var')
    latent_var = Input(shape=(num_latent_dims,), name='latent_var')

    # A single non-trainable embedding layer, initialised from
    # yahoo.get_word_embeddings(), is shared by questions and answers.
    vocab_size, num_embedding_dims = embeddings.shape
    emb = Embedding(vocab_size, num_embedding_dims, weights=[embeddings],
                    trainable=False)
    q_var = emb(question_var)
    a_var = emb(answer_var)

    # Repeat the latent vector QUESTION_MAXLEN times so the generator gets one
    # latent step per question position.
    l_var = RepeatVector(yahoo.QUESTION_MAXLEN)(latent_var)

    # Creates the two models.
    gen_model = build_generator(l_var, a_var, embeddings)
    real_preds, dis_model = build_discriminator(q_var, a_var)

    # Generated question for the latent sequence and answer. It is computed
    # once and fed to the discriminator in both training graphs below.
    q_gen = gen_model([l_var, a_var])

    # Builds the model to train the generator.
    dis_model.trainable = False
    gen_preds = dis_model([q_gen, a_var])

    # Builds the model to train the discriminator.
    dis_model.trainable = True
    gen_model.trainable = False
    fake_preds = dis_model([q_gen, a_var])

    # NOTE(review): real_preds, gen_preds and fake_preds are built here but
    # never wrapped in a Model, compiled, or returned. In Keras, changes to
    # `trainable` only take effect when a model is compiled, so the toggles
    # above do not make gen_preds differ from fake_preds. Confirm how callers
    # are expected to assemble the generator and discriminator training models.
    return gen_model, dis_model
评论列表
文章目录