def __init__(self, training_file='../res/trump_tweets.txt', model_file='../res/model.pt', n_epochs=1000000,
             hidden_size=256, n_layers=2, learning_rate=0.001, chunk_len=140):
    """Prepare a character-level RNN trainer.

    Stores the hyper-parameters, reads the training corpus, then either
    resumes from ``model_file`` (if that path is an existing file) or builds
    a fresh ``RNN``. Finally wires up an Adam optimizer, a cross-entropy
    loss and a ``Generator`` around the decoder.

    :param training_file: path of the text corpus handed to ``read_file``.
    :param model_file: path of a previously saved model to resume from.
    :param n_epochs: number of training epochs (stored only).
    :param hidden_size: RNN hidden size; only used when building a new model.
    :param n_layers: number of RNN layers; only used when building a new model.
    :param learning_rate: Adam learning rate.
    :param chunk_len: length of training chunks (stored only).
    """
    # Hyper-parameters and paths, kept on the instance for later use.
    self.training_file, self.model_file = training_file, model_file
    self.n_epochs = n_epochs
    self.hidden_size, self.n_layers = hidden_size, n_layers
    self.learning_rate, self.chunk_len = learning_rate, chunk_len

    # Read the whole corpus and its length up front.
    self.file, self.file_len = read_file(training_file)

    resume = os.path.isfile(model_file)
    if resume:
        # NOTE(review): torch.load unpickles the whole module object, so only
        # load model files you trust. The saved architecture wins here:
        # hidden_size / n_layers above are not applied to a loaded model.
        self.decoder = torch.load(model_file)
    else:
        self.decoder = RNN(n_characters, hidden_size, n_characters, n_layers)
    print('Loaded old model!' if resume else 'Constructed new model!')

    self.decoder_optimizer = torch.optim.Adam(self.decoder.parameters(), lr=learning_rate)
    self.criterion = nn.CrossEntropyLoss()
    self.generator = Generator(self.decoder)
# Example source code using the Python class RNN
def __init__(self, input_size, hidden_size, output_size, n_layers=1, gpu=-1):
    """Build an RNN decoder together with its optimizer and loss.

    :param input_size: input size forwarded to ``RNN``.
    :param hidden_size: hidden size forwarded to ``RNN``.
    :param output_size: output size forwarded to ``RNN``.
    :param n_layers: number of layers forwarded to ``RNN``.
    :param gpu: GPU index; any negative value keeps the model on the CPU.
    """
    decoder = RNN(input_size, hidden_size, output_size, n_layers, gpu)
    self.decoder = decoder

    use_gpu = gpu >= 0
    if use_gpu:
        # NOTE(review): both the printed index and .cuda() refer to the
        # *current* CUDA device, not necessarily `gpu` -- confirm whether RNN
        # selects the device itself.
        print("Use GPU %d" % torch.cuda.current_device())
        decoder.cuda()

    # Adam with a fixed learning rate of 0.01.
    self.optimizer = torch.optim.Adam(decoder.parameters(), lr=0.01)
    self.criterion = nn.CrossEntropyLoss()
def main(_):
    """Run a random hyper-parameter search for the RNN model and log results.

    The model variant name is derived from the ``ensemble``, ``ngram`` and
    ``embed`` flags. Then, ``valid_iteration`` times, a parameter combination
    is sampled, an ``RNN`` is built with char2vec-initialised embeddings, and
    ``experiment`` is run. Its top-1 / top-5 scores and epoch are appended to
    the file at ``valid_result_path``.

    :param _: unused (presumably the argv passed by ``tf.app.run``).
    :raises ValueError: if ``ensemble`` is off and ``ngram`` is not 1, 2 or 3.
    """
    # Save default params and set scope.
    # NOTE(review): this is whatever object FLAGS.__flags is (presumably the
    # flag dict itself, not a copy), so the writes below may be visible
    # through FLAGS -- confirm.
    saved_params = FLAGS.__flags

    ngram_names = {1: 'unigram', 2: 'bigram', 3: 'trigram'}
    if saved_params['ensemble']:
        model_name = 'ensemble'
    elif saved_params['ngram'] in ngram_names:
        model_name = ngram_names[saved_params['ngram']]
    else:
        # Fail loudly: any other value would leave model_name undefined.
        raise ValueError('Not supported ngram %r' % (saved_params['ngram'],))
    model_name += '_embedding' if saved_params['embed'] else '_no_embedding'
    saved_params['model_name'] = model_name
    saved_params['checkpoint_dir'] += model_name
    pprint.PrettyPrinter().pprint(saved_params)
    saved_dataset = get_data(saved_params)

    # `with` guarantees the results file is closed even if an experiment raises.
    with open(saved_params['valid_result_path'], 'a') as validation_writer:
        validation_writer.write(model_name + "\n")
        validation_writer.write("[dim_hidden, dim_rnn_cell, learning_rate, lstm_dropout, lstm_layer, hidden_dropout, dim_embed]\n")
        validation_writer.write("combination\ttop1\ttop5\tepoch\n")

        # Run the model
        for _ in range(saved_params['valid_iteration']):
            # Sample a parameter set; `combination` is what gets logged.
            params, combination = sample_parameters(saved_params.copy())
            dataset = saved_dataset[:]  # shallow copy of the dataset container

            # Initialize unigram / bigram / trigram embeddings.
            # NOTE(review): dataset[3] is indexed at 0, 4 and 5 -- presumably
            # the matching vocabularies; confirm against get_data.
            uni_init = get_char2vec(dataset[0][0][:], params['dim_embed_unigram'], dataset[3][0])
            bi_init = get_char2vec(dataset[0][1][:], params['dim_embed_bigram'], dataset[3][4])
            tri_init = get_char2vec(dataset[0][2][:], params['dim_embed_trigram'], dataset[3][5])

            print(model_name, 'Parameter sets: ', end='')
            pprint.PrettyPrinter().pprint(combination)
            rnn_model = RNN(params, [uni_init, bi_init, tri_init])
            top1, top5, ep = experiment(rnn_model, dataset, params)

            validation_writer.write(str(combination) + '\t')
            validation_writer.write(str(top1) + '\t' + str(top5) + '\tEp:' + str(ep) + '\n')
def train_and_test(challenge, rnn_cell):
    '''Train an RNN on a story/question/answer task and return its best test accuracy.

    Builds a fresh TensorFlow graph around ``model.RNN``, trains for
    ``FLAGS.num_epochs`` epochs, and evaluates on the full test split after
    every epoch, printing the loss and accuracies.

    :param challenge: task identifier passed to ``helper.extract_file``.
    :param rnn_cell: cell type forwarded to ``model.RNN``.
    :return: the highest per-epoch test accuracy (0 if no epoch ran or none
        scored above 0; 0.0 accuracy is reported for an empty test split).
    '''
    train, test = helper.extract_file(challenge)
    vocab, word_idx, story_maxlen, query_maxlen = helper.get_vocab(train, test)
    vocab_size = len(vocab) + 1  # Reserve 0 for masking via pad_sequences
    x, xq, y = helper.vectorize_stories(train, word_idx, story_maxlen, query_maxlen)
    tx, txq, ty = helper.vectorize_stories(test, word_idx, story_maxlen, query_maxlen)

    with tf.Graph().as_default() as graph:
        story_pl, question_pl, answer_pl, dropout_pl = get_placeholder(vocab_size, story_maxlen, query_maxlen)
        rnn = model.RNN(rnn_cell, FLAGS.embed_dim, FLAGS.rnn_size, vocab_size)
        logits = rnn.inference(story_pl, question_pl, dropout_pl)
        loss = rnn.loss(logits, answer_pl)
        train_op = rnn.train(loss, FLAGS.init_learning_rate)
        correct = rnn.eval(logits, answer_pl)
        init = tf.global_variables_initializer()

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu_fraction)
        with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options), graph=graph) as sess:
            # Initialise all variables of the freshly built graph.
            sess.run(init)
            max_test_acc = 0
            total = len(tx)
            for i in range(FLAGS.num_epochs):
                # Train over all mini-batches; `cost` ends up holding the last
                # batch's loss. NaN if the generator yields nothing, instead
                # of a NameError at the print below.
                cost = float('nan')
                train_gen = helper.generate_data(FLAGS.batch_size, x, xq, y)
                for x_batch, xq_batch, y_batch in train_gen:
                    feed_dict = {story_pl: x_batch, question_pl: xq_batch, answer_pl: y_batch,
                                 dropout_pl: FLAGS.dropout}
                    cost, _ = sess.run([loss, train_op], feed_dict=feed_dict)

                # Evaluate on the whole test split after each epoch. dropout_pl
                # is fed 1.0 here (presumably a keep probability -> no dropout).
                total_correct = 0
                test_gen = helper.generate_data(FLAGS.batch_size, tx, txq, ty)
                for tx_batch, txq_batch, ty_batch in test_gen:
                    feed_dict = {story_pl: tx_batch, question_pl: txq_batch, answer_pl: ty_batch,
                                 dropout_pl: 1.0}
                    total_correct += int(sess.run(correct, feed_dict=feed_dict))
                # Avoid ZeroDivisionError on an empty test split.
                acc = float(total_correct) / total if total else 0.0

                # Track the best test accuracy seen so far (ties keep the old value).
                max_test_acc = max(max_test_acc, acc)
                print('Epoch{:>3} train_loss = {:.3f} accuracy = {:.3f} max_test_acc = {:.3f}'.format(
                    i, cost, acc, max_test_acc))
    return max_test_acc