def get_probmap(model, sess):
    """Run one full pass over ``model.dataset`` and collect per-pair probabilities.

    Returns a ``defaultdict(dict)`` shaped ``{uid -> {pid -> prob}}``, where each
    ``prob`` is ``expit`` (the logistic sigmoid) applied to the corresponding
    element of ``model.lastorder_logits``. If the same (uid, pid) pair appears
    more than once during the pass, the last value seen wins.

    Args:
        model: object exposing ``dataset.new_epoch_op()``, ``dataset['uid']``,
            ``dataset['pid']`` and ``lastorder_logits`` (fetchable tensors/ops).
        sess: session used to run those fetches.
    """
    # Start a fresh pass through the data. The original comment describes this
    # as the validation data; presumably new_epoch_op resets the input
    # iterator — confirm against the dataset class.
    sess.run(model.dataset.new_epoch_op())
    probs_by_user = defaultdict(dict)
    fetches = [model.lastorder_logits, model.dataset['uid'], model.dataset['pid']]

    def _iter_batches():
        # Keep fetching batches until the session raises OutOfRangeError,
        # which this function treats as the end of the pass.
        while True:
            try:
                yield sess.run(fetches)
            except tf.errors.OutOfRangeError:
                return

    n_batches = 0
    n_rows = 0
    for logits, batch_uids, batch_pids in _iter_batches():
        n_rows += len(batch_uids)
        batch_probs = expit(logits)
        for uid, pid, prob in zip(batch_uids, batch_pids, batch_probs):
            probs_by_user[uid][pid] = prob
        n_batches += 1

    tf.logging.info("Computed probabilities for {} users over {} sequences in {} batches".format(
        len(probs_by_user), n_rows, n_batches
    ))
    return probs_by_user
precompute_probs.py 文件源码
python
阅读 29
收藏 0
点赞 0
评论 0
评论列表
文章目录