def ptc_dis_step(self):
    """
    Run one optimization step for the patch discriminator.

    The auto-encoder is put in eval mode and used only to generate fake
    images (decoded with all attributes flipped); no gradient flows back
    into it. The patch discriminator is then trained to tell real images
    from these reconstructions, using smoothed targets: fake images get
    target ``smooth_label`` and real images ``1 - smooth_label``.
    """
    params = self.params
    self.ae.eval()
    self.ptc_dis.train()
    # sample a real batch and flip every attribute for the fake decode
    batch_x, batch_y = self.data.train_batch(params.batch_size)
    flipped = flip_attributes(batch_y, params, 'all')
    # decode with flipped attributes; volatile=True blocks AE gradients
    _, dec_outputs = self.ae(Variable(batch_x.data, volatile=True), flipped)
    # discriminator scores for real images and for the last decoder output
    # (detached via .data so the generator receives no gradient here)
    real_preds = self.ptc_dis(batch_x)
    fake_preds = self.ptc_dis(Variable(dec_outputs[-1].data))
    # smoothed target tensor: filled with smooth_label (used for fakes),
    # and 1 - smooth_label is the target for real predictions
    y_fake = Variable(torch.FloatTensor(real_preds.size())
                      .fill_(params.smooth_label).cuda())
    # total loss = real-vs-(1 - smooth) + fake-vs-smooth
    loss = (F.binary_cross_entropy(real_preds, 1 - y_fake)
            + F.binary_cross_entropy(fake_preds, y_fake))
    self.stats['ptc_dis_costs'].append(loss.data[0])
    # optimize: zero grads, backprop, optional clipping, parameter update
    self.ptc_dis_optimizer.zero_grad()
    loss.backward()
    if params.clip_grad_norm:
        clip_grad_norm(self.ptc_dis.parameters(), params.clip_grad_norm)
    self.ptc_dis_optimizer.step()