def train_and_test(challenge, rnn_cell):
    """Train an RNN model on one challenge and return its best test accuracy.

    The train/test data are loaded with ``helper.extract_file``, vectorized
    against a vocabulary built from both splits, and fed to a ``model.RNN``
    constructed in a fresh graph. After each of ``FLAGS.num_epochs`` training
    epochs the model is evaluated on the test split and one summary line is
    printed.

    :param challenge: dataset selector, passed unchanged to
        ``helper.extract_file``.
    :param rnn_cell: cell specification, passed unchanged to ``model.RNN``.
    :return: the highest test accuracy observed over all epochs (``0`` when
        no epoch scored above zero).
    """
    train, test = helper.extract_file(challenge)
    vocab, word_idx, story_maxlen, query_maxlen = helper.get_vocab(train, test)
    vocab_size = len(vocab) + 1  # Reserve 0 for masking via pad_sequences
    x, xq, y = helper.vectorize_stories(train, word_idx, story_maxlen, query_maxlen)
    tx, txq, ty = helper.vectorize_stories(test, word_idx, story_maxlen, query_maxlen)

    # NOTE(review): using len(tx) as the denominator assumes
    # helper.generate_data yields every test example (no dropped remainder
    # batch) -- confirm against the helper.
    total = len(tx)

    with tf.Graph().as_default() as graph:
        story_pl, question_pl, answer_pl, dropout_pl = get_placeholder(
            vocab_size, story_maxlen, query_maxlen)
        rnn = model.RNN(rnn_cell, FLAGS.embed_dim, FLAGS.rnn_size, vocab_size)
        logits = rnn.inference(story_pl, question_pl, dropout_pl)
        loss = rnn.loss(logits, answer_pl)
        train_op = rnn.train(loss, FLAGS.init_learning_rate)
        correct = rnn.eval(logits, answer_pl)
        init = tf.global_variables_initializer()

        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=FLAGS.gpu_fraction)
        session_config = tf.ConfigProto(gpu_options=gpu_options)
        with tf.Session(config=session_config, graph=graph) as sess:
            # Variables must be initialized before the first training step.
            sess.run(init)
            max_test_acc = 0
            for i in range(FLAGS.num_epochs):
                # Placeholder so the summary line below still works when the
                # training generator yields no batches (previously NameError).
                cost = float('nan')
                train_gen = helper.generate_data(FLAGS.batch_size, x, xq, y)
                for x_batch, xq_batch, y_batch in train_gen:
                    feed_dict = {
                        story_pl: x_batch,
                        question_pl: xq_batch,
                        answer_pl: y_batch,
                        dropout_pl: FLAGS.dropout,
                    }
                    cost, _ = sess.run([loss, train_op], feed_dict=feed_dict)

                # Evaluate on the test split at the end of every epoch.
                test_gen = helper.generate_data(FLAGS.batch_size, tx, txq, ty)
                total_correct = 0
                for tx_batch, txq_batch, ty_batch in test_gen:
                    feed_dict = {
                        story_pl: tx_batch,
                        question_pl: txq_batch,
                        answer_pl: ty_batch,
                        dropout_pl: 1.0,  # dropout placeholder fixed at 1.0 for evaluation
                    }
                    # NOTE(review): presumably rnn.eval returns the number of
                    # correct predictions in the batch -- confirm in model.RNN.
                    cor = sess.run(correct, feed_dict=feed_dict)
                    total_correct += int(cor)

                # An empty test split reports 0.0 instead of dividing by zero.
                acc = float(total_correct) / total if total else 0.0

                # Track the best test accuracy seen so far.
                if acc > max_test_acc:
                    max_test_acc = acc
                print(
                    'Epoch{:>3} train_loss = {:.3f} accuary = {:.3f} max_text_acc = {:.3f}'.format(
                        i, cost, acc, max_test_acc))
    return max_test_acc
评论列表
文章目录