def precompute_probs_for_tag(tag, userfold):
    """Compute, save and return model probabilities for a tagged run.

    Builds the RNN model configured for ``tag`` in inference mode, loads the
    pretrained checkpoint for ``tag``, computes the probability map with
    ``get_probmap`` and persists it with ``common.save_pdict_for_tag``.

    Args:
        tag: Identifier of the hyperparameter set / checkpoint to use.
        userfold: Name of the user fold to build the dataset from. The fold
            ``'validation_full'`` is saved under the name ``'validation'``.

    Returns:
        Whatever ``get_probmap`` returns (presumably a dict, given that it is
        passed to ``save_pdict_for_tag`` -- confirm).

    The TF session is always closed and the default graph is always reset,
    even if an exception is raised. Repeated calls in the same process
    therefore start from a clean graph.
    """
    hps = hypers.hps_for_tag(tag, mode=hypers.Mode.inference)
    try:
        tf.logging.info('Creating model')
        dat = BasketDataset(hps, userfold)
        model = rnnmodel.RNNModel(hps, dat)
        # InteractiveSession installs itself as the default session; kept in
        # case get_probmap relies on that (e.g. via Tensor.eval()).
        sess = tf.InteractiveSession()
        try:
            # Load pretrained weights
            tf.logging.info('Loading weights')
            utils.load_checkpoint_for_tag(tag, sess)
            # TODO: deal with 'test mode'
            tf.logging.info('Calculating probabilities')
            probmap = get_probmap(model, sess)
            # HACK: probabilities computed on 'validation_full' are stored
            # under the plain 'validation' fold name.
            save_fold = 'validation' if userfold == 'validation_full' else userfold
            common.save_pdict_for_tag(tag, probmap, save_fold)
            return probmap
        finally:
            # Release the session even if loading/inference/saving failed.
            sess.close()
    finally:
        # Always clear the graph so a later call does not inherit stale ops
        # (including those from a partially constructed model).
        tf.reset_default_graph()
precompute_probs.py 文件源码
python
阅读 28
收藏 0
点赞 0
评论 0
评论列表
文章目录