def evaluate(model, dataset, params):
    """Restore the latest checkpoint and run the evaluation routines.

    Opens a TensorFlow session that uses ``params.num_cores`` threads for
    both inter- and intra-op parallelism and grows GPU memory on demand.
    It restores every global variable from the checkpoint recorded under
    ``params.model``, then runs ``evaluate_retrieval`` followed by
    ``evaluate_loss`` inside that session.

    Args:
        model: Passed unchanged to ``evaluate_retrieval`` and
            ``evaluate_loss``.
        dataset: Passed unchanged to ``evaluate_retrieval`` and
            ``evaluate_loss``.
        params: Configuration object. This function reads ``num_cores`` and
            ``model``; the latter is handed to
            ``tf.train.get_checkpoint_state``, so it is presumably the
            checkpoint directory (confirm with callers).

    Raises:
        ValueError: If no checkpoint state is found under ``params.model``.
    """
    # Look up the checkpoint before opening a session so a missing
    # checkpoint fails fast with a readable message. get_checkpoint_state
    # returns None when no checkpoint state file exists, and the old code
    # then crashed with an AttributeError on None.
    ckpt = tf.train.get_checkpoint_state(params.model)
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise ValueError(
            "No checkpoint found in {!r}".format(params.model))

    config = tf.ConfigProto(
        inter_op_parallelism_threads=params.num_cores,
        intra_op_parallelism_threads=params.num_cores,
        gpu_options=tf.GPUOptions(allow_growth=True),
    )
    with tf.Session(config=config) as session:
        # The Saver below only covers global variables, so local variables
        # must be initialized explicitly.
        tf.local_variables_initializer().run()
        # Redundant with the restore below (which overwrites every global
        # variable) but harmless; kept to preserve the original behavior.
        tf.global_variables_initializer().run()

        saver = tf.train.Saver(tf.global_variables())
        saver.restore(session, ckpt.model_checkpoint_path)

        evaluate_retrieval(model, dataset, params, session)
        evaluate_loss(model, dataset, params, session)
评论列表
文章目录