bounce.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:MachineLearning 作者: timomernick 项目源码 文件源码
def train():
    """Train ``predictor`` on batches from ``loader``, periodically saving images.

    Relies on module-level names defined elsewhere in this file:
    ``batch_size``, ``num_history``, ``num_components``, ``image_size``,
    ``loader``, ``predictor`` (must expose an ``optimizer`` attribute),
    ``loss``, ``predict_test_sequence``, ``initial_outputs``, and the
    ``torch`` / ``Variable`` / ``transforms`` / ``vutils`` imports.

    For each of ``num_epochs`` epochs, every batch ``(x, y)`` from ``loader``
    is copied into reusable CUDA buffers, run through ``predictor``, and used
    for one ``predictor.optimizer.step()``. Every ``save_every_iteration``
    batches (including batch 0), the result of ``predict_test_sequence()`` is
    written to ``out_dir`` as ``epoch_XXX.png`` (later saves in the same epoch
    overwrite it) and ``latest.png``, and ``initial_outputs`` is written as
    ``initial.png``. After each epoch the model ``state_dict`` is saved to
    ``<model_out_dir>/model_epoch_<epoch>.pth``.
    """
    import os

    # Preallocated CUDA buffers reused for every batch: each batch is copied
    # in after resizing the buffer to that batch's shape.
    input_var = Variable(torch.FloatTensor(batch_size, num_history, num_components, image_size, image_size)).cuda()
    label = Variable(torch.FloatTensor(batch_size, num_components, image_size, image_size)).cuda()

    num_epochs = 25
    save_every_iteration = 100
    out_dir = "test"
    model_out_dir = "."

    # save_image does not create missing directories; make sure the image
    # output directory exists before the first save.
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    # Built once instead of on every batch. With mean 0 and std 1 this
    # Normalize computes (t - 0) / 1, i.e. it leaves values unchanged; it is
    # kept so the input pipeline behaves exactly as before.
    transform = transforms.Compose([
        transforms.Normalize((0.0, 0.0, 0.0), (1.0, 1.0, 1.0))
    ])

    for epoch in range(num_epochs):
        for i, (x, y) in enumerate(loader):
            predictor.zero_grad()

            x = transform(x)
            y = transform(y)

            input_var.data.resize_(x.size()).copy_(x)
            label.data.resize_(y.size()).copy_(y)

            output = predictor(input_var)
            output_loss = loss(output, label)
            output_loss.backward()
            predictor.optimizer.step()

            # `.data[0]` raises on the 0-dim loss tensor returned by
            # PyTorch >= 0.4. Flattening first and converting with float()
            # works for both the old 1-element loss and the new 0-dim one.
            loss_value = float(output_loss.data.view(-1)[0])
            print("[" + str(epoch) + "/ " + str(i) + "] Loss: " + str(loss_value))

            if i % save_every_iteration == 0:
                img_outputs = predict_test_sequence()
                out_file = '%s/epoch_%03d.png' % (out_dir, epoch)
                print("saving to: " + out_file)
                vutils.save_image(torch.FloatTensor(initial_outputs), out_dir + "/initial.png")
                vutils.save_image(img_outputs.data, out_file)
                vutils.save_image(img_outputs.data, out_dir + "/latest.png")

        print("Saving model...")
        torch.save(predictor.state_dict(), '%s/model_epoch_%d.pth' % (model_out_dir, epoch))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号