def get_model(
    data_path,                           # Path to the dataset directory
    hid_dim,                             # Dimension of the hidden GRU layers
    optimizer='rmsprop',                 # Optimizer passed to model.compile
    loss='categorical_crossentropy'      # Loss function passed to model.compile
):
    """Build and compile a bidirectional-GRU question-answering model.

    Reads sequence lengths, vocabulary size and the answer-entity dimension
    from ``<data_path>/metadata/metadata.txt`` (one ``key:value`` pair per
    line, integer values) and a pre-trained embedding matrix from
    ``<data_path>/metadata/weights.npy``.

    The model embeds the story and the query with a shared weight matrix,
    concatenates the two embedded sequences (query first), runs a forward
    and a backward GRU over the concatenation, merges their final states,
    and predicts an answer entity with a softmax layer.

    Returns:
        A compiled Keras ``Model`` taking ``[story, query]`` int32 inputs.

    Raises:
        IOError/OSError if the metadata or weights files are missing;
        KeyError if a required metadata key is absent;
        ValueError if a metadata value is not an integer.
    """
    # Parse "key:value" metadata lines. Using `with` guarantees the file is
    # closed even if int() or a later lookup raises (the original code
    # leaked the handle on any exception before f.close()).
    metadata_dict = {}
    with open(os.path.join(data_path, 'metadata', 'metadata.txt'), 'r') as f:
        for line in f:
            entry = line.split(':')
            metadata_dict[entry[0]] = int(entry[1])
    story_maxlen = metadata_dict['input_length']
    query_maxlen = metadata_dict['query_length']
    vocab_size = metadata_dict['vocab_size']
    entity_dim = metadata_dict['entity_dim']

    # Pre-trained embedding matrix; shape presumably (vocab_size + 2, word_dim)
    # to match the Embedding layers below — TODO confirm against the data prep.
    embed_weights = np.load(os.path.join(data_path, 'metadata', 'weights.npy'))
    word_dim = embed_weights.shape[1]

    ########## MODEL ############
    story_input = Input(shape=(story_maxlen,), dtype='int32', name="StoryInput")
    # vocab_size + 2 leaves room for the mask index 0 plus one extra token
    # (presumably OOV) — verify against the vocabulary builder.
    x = Embedding(input_dim=vocab_size + 2,
                  output_dim=word_dim,
                  input_length=story_maxlen,
                  mask_zero=True,
                  weights=[embed_weights])(story_input)

    query_input = Input(shape=(query_maxlen,), dtype='int32', name='QueryInput')
    x_q = Embedding(input_dim=vocab_size + 2,
                    output_dim=word_dim,
                    input_length=query_maxlen,
                    mask_zero=True,
                    weights=[embed_weights])(query_input)

    # Query tokens are prepended to the story tokens along the time axis.
    concat_embeddings = masked_concat([x_q, x], concat_axis=1)

    # Forward and backward GRU passes over the same concatenated sequence;
    # their final hidden states are concatenated (a bidirectional encoding).
    lstm = GRU(hid_dim, consume_less='gpu')(concat_embeddings)
    reverse_lstm = GRU(hid_dim, consume_less='gpu',
                       go_backwards=True)(concat_embeddings)
    merged = merge([lstm, reverse_lstm], mode='concat')

    # Softmax over candidate answer entities.
    result = Dense(entity_dim, activation='softmax')(merged)

    model = Model(input=[story_input, query_input], output=result)
    model.compile(optimizer=optimizer,
                  loss=loss,
                  metrics=['accuracy'])
    # summary() prints to stdout itself and returns None, so wrapping it in
    # print() emitted a stray "None" line in the original.
    model.summary()
    return model
simple_gru_model.py — file source code
python
Views: 19
Favorites: 0
Likes: 0
Comments: 0
Comment list
Table of contents