poke_eval_df.py source code


Project: instacart-basket-prediction  Author: colinmorris
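```python
import tensorflow as tf  # TensorFlow 1.x API (tf.logging, tf.InteractiveSession)

# hypers, rnnmodel, utils, pred, User, UserWrapper and _get_df come from the
# instacart-basket-prediction project itself; their import paths are not shown here.

def get_eval_df(checkpoint_path="checkpoints/jul1", config='jul1.json'):
  """Load a trained RNN checkpoint and return the data frame built by _get_df for one test user."""
  hps = hypers.get_default_hparams()
  if config:
    # Override the default hyperparameters with values from the JSON config.
    with open(config) as f:
      hps.parse_json(f.read())
  # Switch to inference settings: not training, one example per batch.
  hps.is_training = False
  hps.batch_size = 1
  tf.logging.info('Creating model')
  model = rnnmodel.RNNModel(hps)

  sess = tf.InteractiveSession()
  # Load pretrained weights
  tf.logging.info('Loading weights')
  utils.load_checkpoint(sess, checkpoint_path)

  tf.logging.info('Loading test set')
  user_pb = User()
  # A serialized protobuf is binary data, so open the file in 'rb' mode.
  with open('testuser.pb', 'rb') as f:
    user_pb.ParseFromString(f.read())
  user = UserWrapper(user_pb)

  predictor = pred.RnnModelPredictor(sess, model, .2, predict_nones=0)
  df = _get_df(user, predictor)
  return df
```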
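The function follows a common evaluation-time ordering: start from default hyperparameters, overlay values from a JSON config, then force inference settings (`is_training = False`, `batch_size = 1`) so the config cannot override them. The sketch below reproduces only that ordering in plain Python. The `HParams` class, its fields and defaults, and `eval_hparams` are hypothetical stand-ins, not the project's `hypers` module.

```python
import json
from dataclasses import dataclass, fields


@dataclass
class HParams:
    # Hypothetical fields and defaults; the real hypers.get_default_hparams()
    # defines its own set.
    batch_size: int = 100
    is_training: bool = True
    learning_rate: float = 0.001

    def parse_json(self, text):
        """Override existing fields from a JSON object; reject unknown keys."""
        overrides = json.loads(text)
        known = {f.name for f in fields(self)}
        for key, value in overrides.items():
            if key not in known:
                raise ValueError(f"Unknown hyperparameter: {key}")
            setattr(self, key, value)
        return self


def eval_hparams(config_text=None):
    """Defaults, then optional JSON overrides, then forced inference settings."""
    hps = HParams()
    if config_text:
        hps.parse_json(config_text)
    # Applied last, so these win over anything in the config.
    hps.is_training = False
    hps.batch_size = 1
    return hps


if __name__ == "__main__":
    hps = eval_hparams('{"learning_rate": 0.01, "batch_size": 64}')
    print(hps)
```

Here the config's `learning_rate` is kept, while its `batch_size` of 64 is replaced by 1 because the inference settings are assigned after parsing.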