pytorch_run.py source code

python

Project: pytorch-avitm  Author: hyqneuron
import torch

# train() relies on names defined outside this function:
# args, model, optimizer and tensor_tr.
def train():
    for epoch in range(args.num_epoch):
        # shuffle the training rows and split the indices into mini-batches
        all_indices = torch.randperm(tensor_tr.size(0)).split(args.batch_size)
        loss_epoch = 0.0
        model.train()                   # switch to training mode
        for batch_indices in all_indices:
            if not args.nogpu: batch_indices = batch_indices.cuda()
            batch = tensor_tr[batch_indices]   # no Variable wrapper needed since PyTorch 0.4
            recon, loss = model(batch, compute_loss=True)
            # optimize
            optimizer.zero_grad()       # clear previous gradients
            loss.backward()             # backprop
            optimizer.step()            # update parameters
            # report
            loss_epoch += loss.item()   # accumulate the batch loss as a Python float
        if epoch % 5 == 0:
            # average of the per-batch losses over this epoch
            print('Epoch {}, loss={}'.format(epoch, loss_epoch / len(all_indices)))
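
The batching step in train() may be easier to follow in isolation. The sketch below is a pure-Python analogue of `torch.randperm(n).split(batch_size)`, not PyTorch's implementation; the helper name `shuffled_batches` and its `seed` parameter are hypothetical and exist only for this illustration. It shows that every row index appears exactly once per epoch and that the last batch is shorter when the row count is not a multiple of the batch size.

```python
import math
import random

def shuffled_batches(num_rows, batch_size, seed=None):
    """Pure-Python analogue of torch.randperm(num_rows).split(batch_size).

    Hypothetical helper for illustration only.
    """
    rng = random.Random(seed)
    indices = list(range(num_rows))
    rng.shuffle(indices)  # random permutation of the row indices
    # consecutive chunks of batch_size; the final chunk may be shorter
    return [indices[i:i + batch_size] for i in range(0, num_rows, batch_size)]

batches = shuffled_batches(num_rows=10, batch_size=4, seed=0)
batch_sizes = [len(b) for b in batches]
print(batch_sizes)                          # [4, 4, 2]
print(len(batches) == math.ceil(10 / 4))    # True
```

Because train() divides `loss_epoch` by `len(all_indices)`, the printed value is the mean of the per-batch losses. Assuming the model returns a loss already averaged within each batch, a short final batch then counts as much as a full one in that mean.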