train.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:ResNeXt.pytorch 作者: prlz77 项目源码 文件源码
def train():
        net.train()
        loss_avg = 0.0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = torch.autograd.Variable(data.cuda()), torch.autograd.Variable(target.cuda())

            # forward
            output = net(data)

            # backward
            optimizer.zero_grad()
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()

            # exponential moving average
            loss_avg = loss_avg * 0.2 + loss.data[0] * 0.8
        state['train_loss'] = loss_avg


    # test function (forward only)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号