train_and_test.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:deepnlp-models 作者: ByzenMa 项目源码 文件源码
def train_and_test(challenge, rnn_cell):
    '''
    Train an RNN model on one challenge's training split, evaluating it on the
    test split after every epoch, and report the best test accuracy seen.

    :param challenge: identifier forwarded to ``helper.extract_file`` to load
        the (train, test) data for one task.
    :param rnn_cell: cell specification forwarded unchanged to ``model.RNN``.
    :return: the highest per-epoch test accuracy over ``FLAGS.num_epochs``
        epochs, as a float (0.0 when no epoch is run).
    :raises ValueError: if the vectorized test set is empty, because accuracy
        would be undefined. This is checked before any graph is built or any
        training happens.
    '''
    train, test = helper.extract_file(challenge)
    vocab, word_idx, story_maxlen, query_maxlen = helper.get_vocab(train, test)
    vocab_size = len(vocab) + 1  # Reserve 0 for masking via pad_sequences
    x, xq, y = helper.vectorize_stories(train, word_idx, story_maxlen, query_maxlen)
    tx, txq, ty = helper.vectorize_stories(test, word_idx, story_maxlen, query_maxlen)

    total = len(tx)
    if total == 0:
        # Fail fast: otherwise a whole epoch of training would run before the
        # accuracy computation below divides by zero.
        raise ValueError('test set for challenge {!r} is empty'.format(challenge))

    with tf.Graph().as_default() as graph:
        # Build the model graph: inputs, forward pass, loss, optimizer step and
        # the evaluation op used on the test split.
        story_pl, question_pl, answer_pl, dropout_pl = get_placeholder(vocab_size, story_maxlen, query_maxlen)
        rnn = model.RNN(rnn_cell, FLAGS.embed_dim, FLAGS.rnn_size, vocab_size)
        logits = rnn.inference(story_pl, question_pl, dropout_pl)
        loss = rnn.loss(logits, answer_pl)
        train_op = rnn.train(loss, FLAGS.init_learning_rate)
        correct = rnn.eval(logits, answer_pl)
        init = tf.global_variables_initializer()
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu_fraction)
        with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options), graph=graph) as sess:
            # Initialize all model variables before training.
            sess.run(init)
            max_test_acc = 0.0
            for i in range(FLAGS.num_epochs):
                # Loss of the most recent training batch. It starts as NaN so
                # the progress line below never references an unbound name
                # when an epoch yields no training batches.
                cost = float('nan')
                train_gen = helper.generate_data(FLAGS.batch_size, x, xq, y)
                for x_batch, xq_batch, y_batch in train_gen:
                    feed_dict = {story_pl: x_batch, question_pl: xq_batch, answer_pl: y_batch,
                                 dropout_pl: FLAGS.dropout}
                    cost, _ = sess.run([loss, train_op], feed_dict=feed_dict)

                # Evaluate on the test split at the end of every epoch. The
                # dropout placeholder is fed 1.0 here (presumably a keep
                # probability, i.e. dropout disabled -- confirm in model.RNN).
                test_gen = helper.generate_data(FLAGS.batch_size, tx, txq, ty)
                total_correct = 0
                for tx_batch, txq_batch, ty_batch in test_gen:
                    feed_dict = {story_pl: tx_batch, question_pl: txq_batch, answer_pl: ty_batch,
                                 dropout_pl: 1.0}
                    # Assumes `correct` evaluates to the number of correct
                    # predictions in the batch -- confirm in model.RNN.eval.
                    total_correct += int(sess.run(correct, feed_dict=feed_dict))
                # NOTE(review): assumes generate_data yields every test example
                # exactly once (no dropped partial batch) -- verify in helper.
                acc = float(total_correct) / total
                max_test_acc = max(max_test_acc, acc)
                print(
                    'Epoch{:>3}   train_loss = {:.3f}   accuary = {:.3f}   max_text_acc = {:.3f}'.format(i, cost, acc,
                                                                                                         max_test_acc))
            return max_test_acc
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号