pix2pix.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:pix2pix-pytorch 作者: 1zb 项目源码 文件源码
def _to_float(value):
    """Return a Python number from a single-element loss tensor.

    Prefers ``value.item()`` (available from PyTorch 0.4, where losses are
    0-dim tensors and indexing them with ``[0]`` raises ``IndexError``).
    Falls back to ``value[0]`` for older releases whose loss tensors are
    1-element and have no ``item()`` method.
    """
    item = getattr(value, 'item', None)
    if callable(item):
        return item()
    return value[0]


def train(epoch, l1_weight=100, log_interval=100):
    """Run one training epoch of the pix2pix conditional GAN.

    For each ``(left, right)`` pair from the module-level
    ``training_data_loader``, the pair is copied into the module-level
    ``input`` / ``target`` buffers, then one discriminator update
    (``D_solver``) and one generator update (``G_solver``) are performed.

    Args:
        epoch: Epoch index; used only in the progress message.
        l1_weight: Weight of the smooth-L1 reconstruction term in the
            generator loss (the original hard-coded value was 100).
        log_interval: Print a progress line every ``log_interval`` batches;
            must be a positive integer (the original hard-coded value was 100).

    Relies on module-level names defined elsewhere in the file: ``args``,
    ``training_data_loader``, ``netD``, ``netG``, ``input``, ``target``,
    ``ones_label``, ``zeros_label``, ``D_solver``, ``G_solver`` and ``F``
    (presumably ``torch.nn.functional`` -- confirm against the file's imports).
    """
    for batch, (left, right) in enumerate(training_data_loader):
        # 'lr' translates left -> right; any other value translates right -> left.
        # Data is copied into the existing buffers (resized to fit the batch)
        # rather than rebinding the names.
        if args.direction == 'lr':
            input.data.resize_(left.size()).copy_(left)
            target.data.resize_(right.size()).copy_(right)
        else:
            input.data.resize_(right.size()).copy_(right)
            target.data.resize_(left.size()).copy_(left)

        ## Discriminator
        netD.zero_grad()
        # Real pair (input, target) should be classified as 1.
        D_real = netD(input, target)
        # Label buffers are resized to match the discriminator's output shape.
        ones_label.data.resize_(D_real.size()).fill_(1)
        zeros_label.data.resize_(D_real.size()).fill_(0)
        # binary_cross_entropy expects probabilities, so netD presumably ends
        # in a sigmoid -- NOTE(review): confirm in the model definition.
        D_loss_real = F.binary_cross_entropy(D_real, ones_label)
        D_x_y = D_real.data.mean()

        # Fake pair (input, G(input)) should be classified as 0. The generator
        # output is detached so this loss does not backpropagate into netG.
        G_fake = netG(input)
        D_fake = netD(input, G_fake.detach())
        D_loss_fake = F.binary_cross_entropy(D_fake, zeros_label)
        D_x_gx = D_fake.data.mean()

        D_loss = D_loss_real + D_loss_fake
        D_loss.backward()
        D_solver.step()

        ## Generator
        netG.zero_grad()

        # Reuse G_fake from above instead of running netG forward again.
        # netG's weights have not changed since then, and its graph is intact
        # because the discriminator loss only saw the detached copy.
        # netD *has* just been updated, so it must be re-evaluated.
        D_fake = netD(input, G_fake)
        D_x_gx_2 = D_fake.data.mean()
        # Adversarial term (fool D into predicting 1) plus weighted
        # smooth-L1 reconstruction term against the ground-truth target.
        G_loss = (F.binary_cross_entropy(D_fake, ones_label)
                  + l1_weight * F.smooth_l1_loss(G_fake, target))
        G_loss.backward()
        G_solver.step()

        ## debug
        if (batch + 1) % log_interval == 0:
            print('[TRAIN] Epoch[{}]({}/{}); D_loss: {:.4f}; G_loss: {:.4f}; D(x): {:.4f} D(G(z)): {:.4f}/{:.4f}'.format(
                epoch, batch + 1, len(training_data_loader),
                _to_float(D_loss.data), _to_float(G_loss.data),
                D_x_y, D_x_gx, D_x_gx_2))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号