train.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:speed 作者: keon 项目源码 文件源码
def validate(models, dataset, arg, cuda=False):
    criterion = nn.MSELoss()
    losses = []
    batcher = dataset.get_batcher(shuffle=True, augment=False)
    for b, (x, y) in enumerate(batcher, 1):
        x = V(th.from_numpy(x).float()).cuda()
        y = V(th.from_numpy(y).float()).cuda()
        # Ensemble average
        logit = None
        for model, _ in models:
            model.eval()
            logit = model(x) if logit is None else logit + model(x)
        logit = th.div(logit, len(models))
        loss = criterion(logit, y)
        losses.append(loss.data[0])
    return np.mean(losses)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号