from keras.layers import Embedding, Input, RepeatVector


def build_gan(num_latent_dims):
    """Builds a generative adversarial network.

    To train the GAN, run the updates on the generator and discriminator
    models in a loop.

    Args:
        num_latent_dims: int, number of latent dimensions in the generator.
    """
    # `yahoo`, `build_generator` and `build_discriminator` are helpers from
    # the surrounding module.
    embeddings = yahoo.get_word_embeddings()
    question_var = Input(shape=(yahoo.QUESTION_MAXLEN,), name='question_var')
    answer_var = Input(shape=(yahoo.ANSWER_MAXLEN,), name='answer_var')
    latent_var = Input(shape=(num_latent_dims,), name='latent_var')

    # Shared, frozen word-embedding layer applied to questions and answers.
    vocab_size, num_embedding_dims = embeddings.shape
    emb = Embedding(vocab_size, num_embedding_dims, weights=[embeddings],
                    trainable=False)
    q_var = emb(question_var)
    a_var = emb(answer_var)
    l_var = RepeatVector(yahoo.QUESTION_MAXLEN)(latent_var)

    # Creates the two models.
    gen_model = build_generator(l_var, a_var, embeddings)
    real_preds, dis_model = build_discriminator(q_var, a_var)

    # Builds the model to train the generator (discriminator frozen).
    dis_model.trainable = False
    q_gen = gen_model([l_var, a_var])
    gen_preds = dis_model([q_gen, a_var])

    # Builds the model to train the discriminator (generator frozen).
    dis_model.trainable = True
    gen_model.trainable = False
    fake_preds = dis_model([q_gen, a_var])

    # Computes predictions (pred_model is built elsewhere in the original
    # module and is not shown in this snippet).
    # preds = pred_model([l_var, a_var])

    return gen_model, dis_model
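The docstring above says the GAN is trained by alternating updates on the two models. The loop below is only a minimal sketch of that idea: it assumes that dis_model and a stacked generator-plus-frozen-discriminator model gan_model (the model whose output would be gen_preds above) have both been compiled with binary cross-entropy, that the models accept raw latent vectors and token ids as inputs, and that sample_real_batch is a hypothetical data helper not shown in the snippet.

import numpy as np

def train_gan(gen_model, dis_model, gan_model, num_steps, batch_size, num_latent_dims):
    # Alternates one discriminator update and one generator update per step.
    for step in range(num_steps):
        latents = np.random.normal(size=(batch_size, num_latent_dims))
        questions, answers = sample_real_batch(batch_size)  # hypothetical helper

        # Discriminator step: real questions -> 1, generated questions -> 0.
        fake_questions = gen_model.predict([latents, answers])
        dis_model.train_on_batch([questions, answers], np.ones((batch_size, 1)))
        dis_model.train_on_batch([fake_questions, answers], np.zeros((batch_size, 1)))

        # Generator step: push the frozen discriminator's output toward 1.
        gan_model.train_on_batch([latents, answers], np.ones((batch_size, 1)))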
Example source code for the Python RepeatVector() layer
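For reference, RepeatVector(n) turns a 2D tensor of shape (batch, features) into a 3D tensor of shape (batch, n, features) by repeating it n times along a new time axis. A minimal standalone sketch:

from keras.models import Sequential
from keras.layers import Dense, RepeatVector

model = Sequential()
model.add(Dense(32, input_dim=64))  # output shape: (None, 32)
model.add(RepeatVector(3))          # output shape: (None, 3, 32)
print(model.output_shape)           # (None, 3, 32)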
def create_model(self, ret_model=False):
    # base_model = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
    # base_model.trainable = False

    # Image branch: projects 4096-dim image features to the embedding size
    # and repeats them once per caption timestep.
    image_model = Sequential()
    # image_model.add(base_model)
    # image_model.add(Flatten())
    image_model.add(Dense(EMBEDDING_DIM, input_dim=4096, activation='relu'))
    image_model.add(RepeatVector(self.max_cap_len))

    # Language branch: embeds the partial caption and runs it through an LSTM.
    lang_model = Sequential()
    lang_model.add(Embedding(self.vocab_size, 256, input_length=self.max_cap_len))
    lang_model.add(LSTM(256, return_sequences=True))
    lang_model.add(TimeDistributed(Dense(EMBEDDING_DIM)))

    # Merges both branches and predicts the next word of the caption.
    model = Sequential()
    model.add(Merge([image_model, lang_model], mode='concat'))
    model.add(LSTM(1000, return_sequences=False))
    model.add(Dense(self.vocab_size))
    model.add(Activation('softmax'))

    print("Model created!")

    if ret_model:
        return model

    model.compile(loss='categorical_crossentropy', optimizer='rmsprop',
                  metrics=['accuracy'])
    return model
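The merged captioner takes two inputs per sample, a 4096-dimensional image feature vector and a partial caption of length max_cap_len, and predicts the next word as a softmax over the vocabulary. A hedged usage sketch with hypothetical placeholder data, assuming model is the compiled network returned by create_model above:

import numpy as np

# Hypothetical sizes; in the snippet above these come from the class instance.
vocab_size, max_cap_len, batch = 5000, 40, 8

image_features = np.random.rand(batch, 4096)  # e.g. fc-layer CNN features
partial_captions = np.random.randint(0, vocab_size, (batch, max_cap_len))
next_words = np.eye(vocab_size)[np.random.randint(0, vocab_size, batch)]  # one-hot targets

model.train_on_batch([image_features, partial_captions], next_words)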
def build_model(text_len, negative_size, optimizer, word_size, entity_size,
                dim_size, word_static, entity_static, word_embedding,
                entity_embedding):
    # Text branch: embeds the word ids and builds a single text representation
    # (TextRepresentationLayer is a project-specific layer).
    text_input_layer = Input(shape=(text_len,), dtype='int32')
    word_embed_layer = Embedding(
        word_size, dim_size, input_length=text_len, name='word_embedding',
        weights=[word_embedding], trainable=not word_static
    )(text_input_layer)
    text_layer = TextRepresentationLayer(name='text_layer')(
        [word_embed_layer, text_input_layer]
    )

    # Entity branch: embeds the negative_size + 1 candidate entities.
    entity_input_layer = Input(shape=(negative_size + 1,), dtype='int32')
    entity_embed_layer = Embedding(
        entity_size, dim_size, input_length=negative_size + 1,
        name='entity_embedding', weights=[entity_embedding],
        trainable=not entity_static
    )(entity_input_layer)

    # Scores each candidate against the (repeated) text representation and
    # normalizes the scores into a probability distribution.
    similarity_layer = DotLayer(name='dot_layer')(
        [RepeatVector(negative_size + 1)(text_layer), entity_embed_layer]
    )
    predictions = SoftmaxLayer()(similarity_layer)

    model = Model(input=[text_input_layer, entity_input_layer],
                  output=predictions)
    model.compile(optimizer=optimizer, loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model
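A hedged sketch of calling this builder, assuming randomly initialized placeholder embeddings and arbitrary sizes; the custom TextRepresentationLayer, DotLayer, and SoftmaxLayer come from the surrounding project and are not redefined here:

import numpy as np

text_len, negative_size, dim_size = 50, 30, 300
word_size, entity_size = 100000, 200000

# Random placeholder embeddings; the real ones are pretrained in the project.
word_embedding = np.random.uniform(-0.05, 0.05, (word_size, dim_size))
entity_embedding = np.random.uniform(-0.05, 0.05, (entity_size, dim_size))

model = build_model(text_len, negative_size, 'adam', word_size, entity_size,
                    dim_size, word_static=False, entity_static=False,
                    word_embedding=word_embedding,
                    entity_embedding=entity_embedding)

# One training example: a word-id sequence plus negative_size + 1 candidate
# entity ids; the target marks which candidate is the correct one.
words = np.random.randint(0, word_size, (1, text_len))
entities = np.random.randint(0, entity_size, (1, negative_size + 1))
target = np.zeros((1, negative_size + 1))
target[0, 0] = 1.0
model.train_on_batch([words, entities], target)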