def train():
    """Run the training loop for ``args.num_epoch`` epochs.

    Relies on module-level globals visible elsewhere in the file:
    ``args`` (hyperparameters), ``model``, ``optimizer``, and
    ``tensor_tr`` (the training-data tensor). Prints the mean
    per-batch loss every 5 epochs.
    """
    # `range`, not Python-2 `xrange` (NameError on Python 3).
    for epoch in range(args.num_epoch):
        # Shuffle all sample indices, then chunk them into mini-batches.
        all_indices = torch.randperm(tensor_tr.size(0)).split(args.batch_size)
        loss_epoch = 0.0
        model.train()  # switch to training mode
        for batch_indices in all_indices:
            if not args.nogpu:
                batch_indices = batch_indices.cuda()
            # `Variable(...)` is deprecated since PyTorch 0.4 (tensors carry
            # autograd directly), so index the tensor directly. Renamed from
            # `input`, which shadowed the builtin.
            batch = tensor_tr[batch_indices]
            recon, loss = model(batch, compute_loss=True)
            # optimize
            optimizer.zero_grad()  # clear previous gradients
            loss.backward()        # backprop
            optimizer.step()       # update parameters
            # report: `.item()` replaces the removed `loss.data[0]` idiom
            # (IndexError on 0-dim tensors in PyTorch >= 0.4).
            loss_epoch += loss.item()
        if epoch % 5 == 0:
            print('Epoch {}, loss={}'.format(epoch, loss_epoch / len(all_indices)))
# (removed non-code webpage-scrape residue: "评论列表" / "文章目录" —
#  "comment list" / "article table of contents")