precompute_probs.py source code

python
Project: instacart-basket-prediction, author: colinmorris
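from collections import defaultdict

import tensorflow as tf  # TensorFlow 1.x API (Session.run, tf.logging)
from scipy.special import expit  # logistic sigmoid

def get_probmap(model, sess):
  """Return a nested dict {uid -> {pid -> prob}} of predicted probabilities.

  Each prob is the sigmoid of the model's last-order logit for that (uid, pid) pair.
  """
  # Start a fresh pass through the validation data
  sess.run(model.dataset.new_epoch_op())
  pmap = defaultdict(dict)
  nbatches = 0
  nseqs = 0
  to_fetch = [model.lastorder_logits, model.dataset['uid'], model.dataset['pid']]
  while True:
    try:
      final_logits, uids, pids = sess.run(to_fetch)
    except tf.errors.OutOfRangeError:
      # The dataset iterator is exhausted: the pass is complete
      break
    batch_size = len(uids)
    nseqs += batch_size
    # Convert logits to probabilities with the logistic sigmoid
    final_probs = expit(final_logits)
    for uid, pid, prob in zip(uids, pids, final_probs):
      pmap[uid][pid] = prob
    nbatches += 1
  tf.logging.info("Computed probabilities for {} users over {} sequences in {} batches".format(
    len(pmap), nseqs, nbatches
  ))
  return pmap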
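The batch-draining pattern in get_probmap can be tried without TensorFlow. The sketch below is a minimal, hypothetical stand-in, not part of the original project: `EndOfData` plays the role of `tf.errors.OutOfRangeError`, `make_fetcher` replaces the repeated `sess.run(to_fetch)` calls, and `sigmoid` reimplements the function `scipy.special.expit` computes. The toy batches are made up for illustration.

```python
import math
from collections import defaultdict


class EndOfData(Exception):
    """Hypothetical stand-in for tf.errors.OutOfRangeError."""


def make_fetcher(batches):
    """Return a fetch() that yields one batch per call, like sess.run(to_fetch)."""
    it = iter(batches)

    def fetch():
        try:
            return next(it)
        except StopIteration:
            raise EndOfData() from None

    return fetch


def sigmoid(x):
    """Numerically stable logistic sigmoid (the function scipy.special.expit computes)."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def build_probmap(fetch):
    """Drain fetch() until EndOfData and build {uid -> {pid -> prob}}."""
    pmap = defaultdict(dict)
    nseqs = 0
    nbatches = 0
    while True:
        try:
            logits, uids, pids = fetch()
        except EndOfData:
            break  # data exhausted: the pass is complete
        nseqs += len(uids)
        for uid, pid, logit in zip(uids, pids, logits):
            pmap[uid][pid] = sigmoid(logit)  # a repeated (uid, pid) overwrites the earlier value
        nbatches += 1
    return pmap, nseqs, nbatches


# Two made-up batches: user 1 has products 10 and 11, user 2 has product 10.
batches = [
    ([0.0, 2.0], [1, 1], [10, 11]),
    ([-2.0], [2], [10]),
]
pmap, nseqs, nbatches = build_probmap(make_fetcher(batches))
print(len(pmap), nseqs, nbatches)  # 2 3 2
print(pmap[1][10])  # 0.5
```

As in get_probmap, a (uid, pid) pair that appears in more than one batch keeps only the last probability, because the inner dict assignment overwrites it.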