def train(num_epochs=25, save_every_iteration=100, out_dir="test", model_out_dir="."):
    """Train ``predictor`` on batches from ``loader`` and periodically save outputs.

    Relies on module-level names defined outside this function:
    ``predictor`` (provides ``zero_grad()``, ``state_dict()`` and an
    ``optimizer`` attribute), ``loader`` (yields ``(x, y)`` pairs),
    ``loss`` (called as ``loss(output, label)``), ``predict_test_sequence``
    and ``initial_outputs``. Requires a CUDA device.

    Args:
        num_epochs: Number of passes over ``loader``.
        save_every_iteration: Render test predictions whenever the batch
            index within the current epoch is a multiple of this value
            (batch 0 of every epoch included).
        out_dir: Directory for PNG snapshots; created if missing. Within one
            epoch each snapshot overwrites ``epoch_XXX.png``.
        model_out_dir: Directory for per-epoch ``state_dict`` checkpoints;
            created if missing.
    """
    import os

    # save_image / torch.save fail with FileNotFoundError on a missing directory.
    os.makedirs(out_dir, exist_ok=True)
    os.makedirs(model_out_dir, exist_ok=True)

    # Mean 0 / std 1 makes this Normalize an identity on the values; kept for
    # parity with the original pipeline, but built once instead of per batch.
    transform = transforms.Compose([
        transforms.Normalize((0.0, 0.0, 0.0), (1.0, 1.0, 1.0))
    ])

    for epoch in range(num_epochs):
        for i, (x, y) in enumerate(loader):
            predictor.zero_grad()
            # Move each batch straight to the GPU instead of copying into
            # preallocated buffers via ``.data.resize_``: in current PyTorch,
            # resizing ``.data`` does not resize the tensor itself, so a short
            # final batch would be padded with stale rows. ``.float()`` keeps
            # the float32 cast the original FloatTensor buffers performed.
            inputs = transform(x).float().cuda()
            labels = transform(y).float().cuda()

            output = predictor(inputs)
            output_loss = loss(output, labels)
            output_loss.backward()
            predictor.optimizer.step()
            # ``.item()`` replaces ``.data[0]``, which raises IndexError on a
            # 0-dim loss tensor in PyTorch >= 0.5.
            print("[" + str(epoch) + "/ " + str(i) + "] Loss: " + str(output_loss.item()))

            if i % save_every_iteration == 0:
                img_outputs = predict_test_sequence()
                out_file = '%s/epoch_%03d.png' % (out_dir, epoch)
                print("saving to: " + out_file)
                vutils.save_image(torch.FloatTensor(initial_outputs), out_dir + "/initial.png")
                vutils.save_image(img_outputs.data, out_file)
                vutils.save_image(img_outputs.data, out_dir + "/latest.png")

        # NOTE(review): original indentation was lost; checkpointing is assumed
        # to happen once per epoch since the filename only encodes the epoch.
        print("Saving model...")
        torch.save(predictor.state_dict(), '%s/model_epoch_%d.pth' % (model_out_dir, epoch))
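# NOTE: the two lines below are web-page navigation residue from the page this
# code was copied from, not code: "评论列表" (comment list), "文章目录" (table of contents).
# 评论列表
# 文章目录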