main.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:visual-question-answering-tensorflow 作者: lmelvix 项目源码 文件源码
def train():
    """Train the ABC-CNN visual question answering model.

    Loads the dataset through ``dl.load_questions_answers('data')``, builds
    the model with ``abc.model`` inside a new TensorFlow session, and then
    runs iterations ``1 .. max_itr - 1``, where ``max_itr`` is ten times the
    number of whole batches in ``vqa['training']``.

    Each iteration:
      * pulls one training batch and one validation batch from the loaders,
      * applies one optimizer step on the training batch,
      * re-evaluates summary/loss/accuracy on that same training batch
        (i.e. after the weight update) with full tracing enabled,
      * evaluates loss/accuracy on the validation batch,
      * writes the run metadata and training summary to the summary writer
        and prints a progress line.

    Takes no arguments and returns nothing; all results go to stdout and to
    the writer returned by ``abc.write_tensorboard``.
    """
    batch_size = 10
    print("Starting ABC-CNN training")
    vqa = dl.load_questions_answers('data')

    # NOTE: training runs on the full dataset returned by the loader.
    train_size = len(vqa['training'])
    # Presumably ten passes over the training split -- TODO confirm that
    # util.get_batch yields exactly train_size // batch_size batches per pass.
    max_itr = (train_size // batch_size) * 10

    with tf.Session() as sess:
        image, ques, ans, optimizer, loss, accuracy = abc.model(sess, batch_size)
        print("Defined ABC model")

        train_loader = util.get_batch(sess, vqa, batch_size, 'training')
        print("Created train dataset generator")

        valid_loader = util.get_batch(sess, vqa, batch_size, 'validation')
        print("Created validation dataset generator")

        writer = abc.write_tensorboard(sess)
        init = tf.global_variables_initializer()
        merged = tf.summary.merge_all()
        sess.run(init)
        print("Initialized Tensor variables")

        # The trace options never change, so build them once instead of on
        # every iteration. RunMetadata is filled in by each run, so a fresh
        # instance is still created per step below.
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)

        # Starts at 1 and stops before max_itr, matching the step numbers
        # used for the summary writer.
        for itr in range(1, max_itr):
            run_metadata = tf.RunMetadata()

            # The first element of each batch tuple is not used here.
            _, vgg_batch, ques_batch, answer_batch = next(train_loader)
            _, valid_vgg_batch, valid_ques_batch, valid_answer_batch = next(valid_loader)

            train_feed = {image: vgg_batch, ques: ques_batch, ans: answer_batch}
            valid_feed = {image: valid_vgg_batch,
                          ques: valid_ques_batch,
                          ans: valid_answer_batch}

            # Weight update first; the metrics below are therefore measured
            # after this step has been applied.
            sess.run(optimizer, feed_dict=train_feed)
            train_summary, train_loss, train_accuracy = sess.run(
                [merged, loss, accuracy],
                feed_dict=train_feed,
                options=run_options,
                run_metadata=run_metadata)
            valid_loss, valid_accuracy = sess.run([loss, accuracy],
                                                  feed_dict=valid_feed)

            writer.add_run_metadata(run_metadata, 'step%03d' % itr)
            writer.add_summary(train_summary, itr)
            writer.flush()
            print("Iteration:%d\tTraining Loss:%f\tTraining Accuracy:%f\tValidation Loss:%f\tValidation Accuracy:%f" % (
                itr, train_loss, 100. * train_accuracy, valid_loss, 100. * valid_accuracy))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号