def _resolve_model_name(params):
    """Return the model name built from the ``ensemble``, ``ngram`` and ``embed`` flags.

    The base name is ``'ensemble'`` when ``params['ensemble']`` is truthy,
    otherwise ``'unigram'``/``'bigram'``/``'trigram'`` for ``ngram`` 1/2/3.
    The suffix is ``'_embedding'`` or ``'_no_embedding'`` depending on
    ``params['embed']``.

    Raises:
        ValueError: if ``ensemble`` is falsy and ``ngram`` is not 1, 2 or 3.
    """
    if params['ensemble']:
        base = 'ensemble'
    else:
        ngram_names = {1: 'unigram', 2: 'bigram', 3: 'trigram'}
        base = ngram_names.get(params['ngram'])
        if base is None:
            # The original used `assert True, ...`, which never fires and left
            # the name unbound; raise explicitly (also survives `python -O`).
            raise ValueError('Not supported ngram %r' % (params['ngram'],))
    suffix = '_embedding' if params['embed'] else '_no_embedding'
    return base + suffix


def main(_):
    """Run a random hyper-parameter search for the configured RNN model.

    Derives a model name from the flags, loads the data once, and then, for
    ``valid_iteration`` rounds, samples a parameter set, builds embedding
    initialisations and an ``RNN`` model, runs ``experiment`` on it, and
    appends the scores to the file at ``valid_result_path``.

    Args:
        _: Unused (presumably the argv passed by ``tf.app.run``; confirm).

    Raises:
        ValueError: if ``ensemble`` is off and ``ngram`` is not 1, 2 or 3.
    """
    # NOTE(review): FLAGS.__flags is a private attribute that (in older TF
    # versions) maps flag names to raw values. The in-place edits below also
    # change the global flag values; confirm nothing relies on that before
    # switching to a copy.
    saved_params = FLAGS.__flags
    model_name = _resolve_model_name(saved_params)
    saved_params['model_name'] = model_name
    saved_params['checkpoint_dir'] += model_name
    pprint.PrettyPrinter().pprint(saved_params)
    saved_dataset = get_data(saved_params)

    # A context manager guarantees the results file is closed (and flushed)
    # even if a training run raises part-way through the search.
    with open(saved_params['valid_result_path'], 'a') as validation_writer:
        validation_writer.write(model_name + "\n")
        validation_writer.write("[dim_hidden, dim_rnn_cell, learning_rate, lstm_dropout, lstm_layer, hidden_dropout, dim_embed]\n")
        validation_writer.write("combination\ttop1\ttop5\tepoch\n")

        for _ in range(saved_params['valid_iteration']):
            # Sample a parameter set from a copy so the shared flags are not
            # touched by sampling.
            params, combination = sample_parameters(saved_params.copy())
            # Shallow copy: the top-level list is fresh, inner objects are shared.
            dataset = saved_dataset[:]

            # Initialise embeddings for each n-gram order. The dataset[0][k] /
            # dataset[3][i] indices are kept as in the original; presumably
            # training inputs and per-order vocabularies (TODO confirm).
            uni_init = get_char2vec(dataset[0][0][:], params['dim_embed_unigram'], dataset[3][0])
            bi_init = get_char2vec(dataset[0][1][:], params['dim_embed_bigram'], dataset[3][4])
            tri_init = get_char2vec(dataset[0][2][:], params['dim_embed_trigram'], dataset[3][5])

            print(model_name, 'Parameter sets: ', end='')
            pprint.PrettyPrinter().pprint(combination)

            rnn_model = RNN(params, [uni_init, bi_init, tri_init])
            top1, top5, ep = experiment(rnn_model, dataset, params)

            validation_writer.write(str(combination) + '\t')
            validation_writer.write(str(top1) + '\t' + str(top5) + '\tEp:' + str(ep) + '\n')
            # Each iteration is an expensive training run; flush so completed
            # results are on disk even if the process is killed later.
            validation_writer.flush()
# (Scraped web-page residue, not code: "评论列表" = "Comment list",
#  "文章目录" = "Table of contents".)