train.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:speed 作者: keon 项目源码 文件源码
def train(e, model, opt, dataset, arg, cuda=False):
    model.train()
    criterion = nn.MSELoss()
    losses = []

    batcher = dataset.get_batcher(shuffle=True, augment=True)
    for b, (x, y) in enumerate(batcher, 1):
        x = V(th.from_numpy(x).float()).cuda()
        y = V(th.from_numpy(y).float()).cuda()
        opt.zero_grad()
        logit = model(x)
        loss = criterion(logit, y)
        loss.backward()
        opt.step()

        losses.append(loss.data[0])
        if arg.verbose and b % 50 == 0:
            loss_t = np.mean(losses[:-49])
            print('[train] [e]:%s [b]:%s - [loss]:%s' % (e, b, loss_t))
    return losses
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号