# Assumes: import torch.nn.functional as F, plus netG/netD, their optimizers
# (D_solver/G_solver), and the input/target/label Variable buffers created at
# module level (pre-0.4 PyTorch idiom; see the setup sketch below).
def train(epoch):
    for batch, (left, right) in enumerate(training_data_loader):
        # Choose the mapping direction: 'lr' trains left -> right, else right -> left.
        if args.direction == 'lr':
            input.data.resize_(left.size()).copy_(left)
            target.data.resize_(right.size()).copy_(right)
        else:
            input.data.resize_(right.size()).copy_(right)
            target.data.resize_(left.size()).copy_(left)

        ## Discriminator update
        netD.zero_grad()
        # Real pair: push D(input, target) toward 1.
        D_real = netD(input, target)
        ones_label.data.resize_(D_real.size()).fill_(1)
        zeros_label.data.resize_(D_real.size()).fill_(0)
        D_loss_real = F.binary_cross_entropy(D_real, ones_label)
        D_x_y = D_real.data.mean()
        # Fake pair: push D(input, G(input)) toward 0. detach() keeps this
        # step from propagating gradients into the generator.
        G_fake = netG(input)
        D_fake = netD(input, G_fake.detach())
        D_loss_fake = F.binary_cross_entropy(D_fake, zeros_label)
        D_x_gx = D_fake.data.mean()
        D_loss = D_loss_real + D_loss_fake
        D_loss.backward()
        D_solver.step()

        ## Generator update: fool D, plus a heavily weighted (x100)
        ## smooth-L1 reconstruction term against the ground truth.
        netG.zero_grad()
        G_fake = netG(input)
        D_fake = netD(input, G_fake)
        D_x_gx_2 = D_fake.data.mean()
        G_loss = (F.binary_cross_entropy(D_fake, ones_label)
                  + 100 * F.smooth_l1_loss(G_fake, target))
        G_loss.backward()
        G_solver.step()

        ## Progress report every 100 batches
        if (batch + 1) % 100 == 0:
            print('[TRAIN] Epoch[{}]({}/{}); D_loss: {:.4f}; G_loss: {:.4f}; '
                  'D(x): {:.4f} D(G(z)): {:.4f}/{:.4f}'.format(
                      epoch, batch + 1, len(training_data_loader),
                      D_loss.data[0], G_loss.data[0], D_x_y, D_x_gx, D_x_gx_2))
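# ---------------------------------------------------------------------------
# Context (a minimal sketch, not from the original post): train() relies on
# globals defined elsewhere -- the networks, their optimizers, the data
# loader, the args flags, and the reusable Variable buffers that the loop
# resizes in place. Everything below is an illustrative assumption: the tiny
# stand-in modules, the Adam hyperparameters, and the toy dataset are
# placeholders, not the author's actual models or data. The .data.resize_()
# and loss.data[0] idioms target pre-0.4 PyTorch.
# ---------------------------------------------------------------------------
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset


class TinyD(nn.Module):
    """Stand-in conditional discriminator: scores an (input, output) pair."""
    def __init__(self):
        super(TinyD, self).__init__()
        self.net = nn.Sequential(nn.Conv2d(6, 1, 3, padding=1), nn.Sigmoid())

    def forward(self, x, y):
        # Condition on the input by concatenating the pair along channels.
        return self.net(torch.cat([x, y], 1))


netG = nn.Sequential(nn.Conv2d(3, 3, 3, padding=1), nn.Tanh())  # toy generator
netD = TinyD()
G_solver = torch.optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.999))
D_solver = torch.optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Reusable buffers, created once and resized to each batch inside train().
input = Variable(torch.FloatTensor(1))
target = Variable(torch.FloatTensor(1))
ones_label = Variable(torch.FloatTensor(1))
zeros_label = Variable(torch.FloatTensor(1))

# Toy paired dataset standing in for the real (left, right) image pairs.
pairs = TensorDataset(torch.rand(8, 3, 32, 32), torch.rand(8, 3, 32, 32))
training_data_loader = DataLoader(pairs, batch_size=4)
args = argparse.Namespace(direction='lr')

for epoch in range(1, 3):  # a couple of epochs just to exercise the loop
    train(epoch)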