main.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:scientific-paper-summarisation 作者: EdCo95 项目源码 文件源码
def main():
    """Train the model, then report its accuracy on a held-out test split.

    Builds the placeholders, loads every feed dict via ``load_data`` and holds
    out the first fifth of them as the final test set. It trains on the
    remaining feed dicts with ``train``, then prints the per-batch accuracy
    and the mean accuracy over the held-out batches.
    """
    # Create the TensorFlow placeholders
    placeholders = create_placeholders()

    # Get the training feed dicts and define the length of the test set.
    train_feed_dicts, vocab = load_data(placeholders)
    # Hold out one fifth of the batches. Floor division matches
    # int(n * (1 / 5)) on Python 3. It does not collapse to 0 under
    # Python 2's integer `/`.
    num_test = len(train_feed_dicts) // 5

    print("Number of Feed Dicts: ", len(train_feed_dicts))
    print("Number of Test Dicts: ", num_test)

    # Slice the dictionary list into training and test sets
    final_test_feed_dicts = train_feed_dicts[0:num_test]
    # NOTE(review): these 50 monitoring batches come from the front of the
    # full list, so they overlap final_test_feed_dicts. If train() uses them
    # for model selection or early stopping, the final test accuracy below is
    # optimistically biased. Confirm how train() consumes them.
    test_feed_dicts = train_feed_dicts[0:50]
    train_feed_dicts = train_feed_dicts[num_test:]

    # Do not take up all the GPU memory, all the time.
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True

    with tf.Session(config=sess_config) as sess:

        # Only the accuracy op is needed for the evaluation below.
        _, _, _, accuracy, _ = train(placeholders, train_feed_dicts, test_feed_dicts, vocab, sess=sess)
        print('============')

        # With fewer than 5 feed dicts num_test is 0. Averaging would then
        # divide by zero, so report that and skip the final evaluation.
        if not final_test_feed_dicts:
            print("No test feed dicts; skipping evaluation on the test set.")
            return

        # Evaluate on the held-out test batches (excluded from train_feed_dicts).
        total_acc = 0
        for batch in final_test_feed_dicts:
            acc = sess.run(accuracy, feed_dict=batch)
            print("Accuracy on test set is: ", acc)
            total_acc += acc
            print('-----')

        print("Overall Average Accuracy on the Test Set Is: ", total_acc / len(final_test_feed_dicts))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号