def get_eval_df(checkpoint_path="checkpoints/jul1", config='jul1.json',
                testuser_path='testuser.pb'):
    """Run a pretrained RNN model over a serialized test user and collect results.

    Steps:
      1. Start from the default hyperparameters.
      2. Optionally override them from a JSON config file.
      3. Switch to evaluation settings (``is_training = False``,
         ``batch_size = 1``) and build ``rnnmodel.RNNModel``.
      4. Restore weights from ``checkpoint_path`` into a new
         ``tf.InteractiveSession``.
      5. Parse the ``User`` protobuf stored at ``testuser_path``.
      6. Hand the wrapped user and an ``RnnModelPredictor`` to ``_get_df``.

    Args:
        checkpoint_path: Checkpoint location passed to
            ``utils.load_checkpoint``.
        config: Path to a JSON file whose text is fed to ``hps.parse_json``.
            A falsy value (``None``, ``''``) skips the override and keeps the
            defaults.
        testuser_path: Path to the binary-serialized ``User`` protobuf used as
            the test set. Defaults to ``'testuser.pb'``, the previously
            hard-coded value.

    Returns:
        Whatever ``_get_df(user, predictor)`` returns. The name suggests a
        pandas DataFrame; confirm against ``_get_df``.
    """
    hps = hypers.get_default_hparams()
    if config:
        with open(config) as f:
            hps.parse_json(f.read())
    # Evaluation settings: not training, one example per batch.
    hps.is_training = False
    hps.batch_size = 1

    tf.logging.info('Creating model')
    model = rnnmodel.RNNModel(hps)
    # NOTE(review): this session is never closed. If callers do not use it
    # after _get_df returns, consider closing it to release graph resources.
    sess = tf.InteractiveSession()

    # Load pretrained weights
    tf.logging.info('Loading weights')
    utils.load_checkpoint(sess, checkpoint_path)

    tf.logging.info('Loading test set')
    user_pb = User()
    # ParseFromString expects raw bytes. Text mode yields str on Python 3 and
    # can mangle newline bytes on Windows, so read the file in binary mode.
    with open(testuser_path, 'rb') as f:
        user_pb.ParseFromString(f.read())
    user = UserWrapper(user_pb)

    # NOTE(review): the meaning of the positional 0.2 argument is defined by
    # RnnModelPredictor (possibly a threshold) — confirm there.
    predictor = pred.RnnModelPredictor(sess, model, .2, predict_nones=0)
    df = _get_df(user, predictor)
    return df
poke_eval_df.py 文件源码
python
阅读 26
收藏 0
点赞 0
评论 0
评论列表
文章目录