depthdi_pred.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:AMS 作者: EthanTaylor2 项目源码 文件源码
def run_train():
    """Evaluate a saved VGG16 model on the test set and log the metrics.

    Despite the name, no training happens here: the op handed to
    ``run_epoch`` is ``tf.no_op()`` and ``istraining=False`` is passed.

    Steps:

    * Build the VGG16 graph from ``FLAGS.vgg16_file_path``.
    * Restore weights via ``load_model`` from ``FLAGS.saveModelPath``.
    * Run one ``run_epoch`` pass over the ``FLAGS.file_path_test`` dataset.
    * Print the resulting top-1 / top-5 accuracy and loss, and write them to
      ``inf.txt``. The file is opened with mode ``'w+'``, so it is truncated
      on every call.

    ``inf.txt`` and a summary writer for ``FLAGS.log_dir + '/test'`` are also
    passed to ``run_epoch``. What it writes to them is decided there.

    Both the log file and the summary writer are now released even if any
    step raises. Previously they leaked on error.
    """
    test_config = ModelConfig()
    test_config.keep_prob = 1.0  # presumably disables dropout for evaluation
    test_config.batch_size = 1

    session_config = tf.ConfigProto(allow_soft_placement=True)
    # Allocate GPU memory on demand rather than reserving it all up front.
    session_config.gpu_options.allow_growth = True

    # The context manager guarantees inf.txt is closed even when graph
    # construction, checkpoint restore or evaluation raises.
    with open('inf.txt', 'w+') as fout:
        with tf.Graph().as_default(), tf.Session(config=session_config) as sess:
            with tf.device('/gpu:0'):
                initializer = tf.random_uniform_initializer(
                    -test_config.init_scale, test_config.init_scale)

                train_model = vgg16.Vgg16(FLAGS.vgg16_file_path)
                train_model.build(initializer)

                data_test = dataset.DataSet(FLAGS.file_path_test,
                                            FLAGS.data_root_dir,
                                            TEST_SIZE,
                                            is_train_set=False)

                test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/test')
                try:
                    saver = tf.train.Saver(max_to_keep=100)
                    last_epoch = load_model(sess, saver, FLAGS.saveModelPath,
                                            train_model)
                    print('start: ', last_epoch + 1)

                    # NOTE(review): the meaning of the literal 2 is defined by
                    # run_epoch; confirm against its signature.
                    test_accury_1, test_accury_5, test_loss = run_epoch(
                        sess, test_config.keep_prob, fout,
                        test_config.batch_size, train_model, data_test,
                        tf.no_op(), 2, test_writer, istraining=False)
                    info = "Final: Test accury(top 1): %.4f Test accury(top 5): %.4f Loss %.4f" % (test_accury_1, test_accury_5, test_loss)
                    print(info)
                    fout.write(info + '\n')
                    fout.flush()
                finally:
                    # Release the event file even if evaluation failed.
                    test_writer.close()

                print("Training step is compeleted!")
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号