precompute_probs.py source code

python

Project: instacart-basket-prediction  Author: colinmorris
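def precompute_probs_for_tag(tag, userfold):
  """Run the pretrained model for `tag` over `userfold`, save the
  resulting probability map, and return it."""
  # Hyperparameters for this tag, configured for inference.
  hps = hypers.hps_for_tag(tag, mode=hypers.Mode.inference)
  tf.logging.info('Creating model')
  dat = BasketDataset(hps, userfold)
  model = rnnmodel.RNNModel(hps, dat)
  sess = tf.InteractiveSession()
  # Load pretrained weights from the checkpoint associated with this tag.
  tf.logging.info('Loading weights')
  utils.load_checkpoint_for_tag(tag, sess)
  # TODO: deal with 'test mode'
  tf.logging.info('Calculating probabilities')
  probmap = get_probmap(model, sess)
  # Hack: results computed on 'validation_full' are saved under the
  # 'validation' fold name.
  if userfold == 'validation_full':
    userfold = 'validation'
  common.save_pdict_for_tag(tag, probmap, userfold)
  # Close the session and clear the default graph so the function can be
  # called again (e.g. for another tag) in the same process.
  sess.close()
  tf.reset_default_graph()
  return probmap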
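The save step above relies on project helpers (`common.save_pdict_for_tag`) whose implementation is not shown. The following is a minimal, hypothetical sketch of how a probability map could be persisted per tag and fold, including the `validation_full` to `validation` renaming done in the function. Everything here is an assumption for illustration: the helper names `canonical_fold`, `probmap_path`, `save_probmap` and `load_probmap`, the pickle format, the `<tag>_<fold>.pickle` file naming, and the map layout (user id to {product id: probability}). None of it is the project's actual `common` module.

```python
import os
import pickle
import tempfile


def canonical_fold(userfold):
    """Map 'validation_full' to 'validation', mirroring the hack in the function above."""
    return 'validation' if userfold == 'validation_full' else userfold


def probmap_path(root, tag, userfold):
    # Hypothetical file layout: <root>/<tag>_<fold>.pickle
    return os.path.join(root, '{}_{}.pickle'.format(tag, canonical_fold(userfold)))


def save_probmap(root, tag, probmap, userfold):
    """Pickle `probmap` under a path derived from the tag and (renamed) fold."""
    path = probmap_path(root, tag, userfold)
    with open(path, 'wb') as f:
        pickle.dump(probmap, f)
    return path


def load_probmap(root, tag, userfold):
    """Load a previously saved probability map for this tag and fold."""
    with open(probmap_path(root, tag, userfold), 'rb') as f:
        return pickle.load(f)


if __name__ == '__main__':
    # Illustrative map only: user id -> {product id: probability}.
    probmap = {1: {196: 0.8, 14084: 0.35}, 2: {196: 0.1}}
    with tempfile.TemporaryDirectory() as root:
        path = save_probmap(root, 'mytag', probmap, 'validation_full')
        print(os.path.basename(path))  # mytag_validation.pickle
        # Saved under 'validation', so it loads back by that name.
        assert load_probmap(root, 'mytag', 'validation') == probmap
```

Because the fold name is normalized in one place, results computed on `validation_full` and on `validation` end up in the same file, which is the effect the original hack achieves.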