mnist.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:studio 作者: studioml 项目源码 文件源码
def train(epoch, reporter):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        reporter.record(batch_idx, loss=loss.data[0])
        reporter.report()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号