predict.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:speed 作者: keon 项目源码 文件源码
def predict(models, dataset, arg, cuda=False):
    prediction_file = open('save/predictions.txt', 'w')
    batcher = dataset.get_batcher(shuffle=False, augment=False)
    for b, (x, _) in enumerate(batcher, 1):
        x = V(th.from_numpy(x).float()).cuda()
        # Ensemble average
        logit = None
        for model, _ in models:
            model.eval()
            logit = model(x) if logit is None else logit + model(x)
        logit = th.div(logit, len(models))
        prediction = logit.cpu().data[0][0]
        prediction_file.write('%s\n' % prediction)
        if arg.verbose and b % 100 == 0:
            print('[predict] [b]:%s - prediction: %s' % (b, prediction))
    # prediction_file.close()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号