basic.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:pytorch.snapshot.ensembles 作者: moskomule 项目源码 文件源码
def train_normal(model, epochs, vis=None):
    """Train ``model`` with Adam at its default settings and optionally plot progress.

    The function relies on two module-level names defined outside this
    block: ``train_loader`` (iterated once per epoch, yielding
    ``(data, target)`` batches) and ``cuda`` (when truthy, each batch is
    moved to the GPU with ``.cuda()``).

    Args:
        model: Callable module exposing ``parameters()``; its output is fed
            to ``F.nll_loss`` (presumably log-probabilities -- confirm
            against the model definition).
        epochs: Number of passes over ``train_loader``.
        vis: Optional plotting client exposing
            ``line(Y, X, win=..., opts=...)`` (presumably a
            ``visdom.Visdom`` instance -- confirm with the caller). When it
            is not ``None``, the learning-rate and loss curves are redrawn
            every 10th epoch (epochs 0, 10, 20, ...).

    Returns:
        The same ``model`` object, whose parameters were updated in place.
    """
    optimizer = optim.Adam(model.parameters())
    lr_history, loss_history = [], []

    for epoch in range(epochs):
        epoch_loss = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            if cuda:
                data, target = data.cuda(), target.cuda()
            data, target = Variable(data), Variable(target)

            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)

            # Indexing a 0-dim loss tensor with ``[0]`` is rejected by newer
            # PyTorch releases; ``item()`` is the supported scalar accessor.
            # Fall back to the legacy form where ``item`` does not exist.
            if hasattr(loss, "item"):
                batch_loss = loss.item()
            else:
                batch_loss = loss.data[0]

            # Accumulate the mean per-batch loss over the epoch.
            epoch_loss += batch_loss / len(train_loader)
            loss.backward()
            optimizer.step()

        loss_history.append(epoch_loss)
        # Read the live learning rate of the first parameter group directly
        # instead of building a full state_dict every epoch.
        lr_history.append(optimizer.param_groups[0]["lr"])

        if vis is not None and epoch % 10 == 0:
            # X axis is the epoch index 0..epoch, matching the history length.
            vis.line(np.array(lr_history), np.arange(epoch + 1), win="lr_n",
                     opts=dict(title="learning rate",
                               xlabel="epochs",
                               ylabel="learning rate (normal)"))
            vis.line(np.array(loss_history), np.arange(epoch + 1), win="loss_n",
                     opts=dict(title="loss",
                               xlabel="epochs",
                               ylabel="training loss (normal)"))

    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号