import logging

import torch
import torch.nn.functional as F
from torch.autograd import Variable        # pre-0.4 PyTorch API, as in the original code
from torch.nn.utils import clip_grad_norm  # renamed clip_grad_norm_ in PyTorch >= 1.0

logger = logging.getLogger(__name__)

# get_attr_loss, get_lambda and flip_attributes are project helpers;
# hedged sketches of all three follow the function below.


def autoencoder_step(self):
    """
    Train the autoencoder with the reconstruction (MSE) loss.
    Train the encoder and decoder with the discriminator losses.
    """
    data = self.data
    params = self.params
    self.ae.train()
    # freeze the discriminators during the autoencoder update
    if params.n_lat_dis:
        self.lat_dis.eval()
    if params.n_ptc_dis:
        self.ptc_dis.eval()
    if params.n_clf_dis:
        self.clf_dis.eval()
    bs = params.batch_size
    # batch / encode / decode
    batch_x, batch_y = data.train_batch(bs)
    enc_outputs, dec_outputs = self.ae(batch_x, batch_y)
    # autoencoder loss from reconstruction (pixel-wise MSE)
    loss = params.lambda_ae * ((batch_x - dec_outputs[-1]) ** 2).mean()
    self.stats['rec_costs'].append(loss.data[0])  # loss.item() on PyTorch >= 0.4
    # encoder loss from the latent discriminator: penalize codes from
    # which the discriminator can still recover the true attributes
    if params.lambda_lat_dis:
        lat_dis_preds = self.lat_dis(enc_outputs[-1 - params.n_skip])
        lat_dis_loss = get_attr_loss(lat_dis_preds, batch_y, True, params)
        loss = loss + get_lambda(params.lambda_lat_dis, params) * lat_dis_loss
    # decode the same codes with randomly flipped attribute labels
    if params.lambda_ptc_dis + params.lambda_clf_dis > 0:
        flipped = flip_attributes(batch_y, params, 'all')
        dec_outputs_flipped = self.ae.decode(enc_outputs, flipped)
    # autoencoder loss from the patch discriminator: decodings with flipped
    # attributes should be classified as real (smoothed real labels)
    if params.lambda_ptc_dis:
        ptc_dis_preds = self.ptc_dis(dec_outputs_flipped[-1])
        y_fake = Variable(torch.FloatTensor(ptc_dis_preds.size())
                          .fill_(params.smooth_label).cuda())
        ptc_dis_loss = F.binary_cross_entropy(ptc_dis_preds, 1 - y_fake)
        loss = loss + get_lambda(params.lambda_ptc_dis, params) * ptc_dis_loss
    # autoencoder loss from the classifier discriminator: decodings with
    # flipped attributes should be classified as carrying those attributes
    if params.lambda_clf_dis:
        clf_dis_preds = self.clf_dis(dec_outputs_flipped[-1])
        clf_dis_loss = get_attr_loss(clf_dis_preds, flipped, False, params)
        loss = loss + get_lambda(params.lambda_clf_dis, params) * clf_dis_loss
    # check NaN (NaN is the only value not equal to itself;
    # torch.isnan(loss).any() on PyTorch >= 0.4)
    if (loss != loss).data.any():
        logger.error("NaN detected")
        exit()
    # optimize
    self.ae_optimizer.zero_grad()
    loss.backward()
    if params.clip_grad_norm:
        clip_grad_norm(self.ae.parameters(), params.clip_grad_norm)
    self.ae_optimizer.step()
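

# `get_lambda` is not shown in this snippet. A minimal sketch of what it is
# assumed to do: linearly warm the loss coefficient up from 0 over the first
# iterations (`params.lambda_schedule` and `params.n_total_iter` are assumed
# names, not confirmed by the snippet).
def get_lambda(l, params):
    # Ramp the coefficient from 0 to l, then keep it constant.
    s = params.lambda_schedule
    if s == 0:
        return l
    return l * float(min(params.n_total_iter, s)) / s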
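

# `get_attr_loss` is also a project helper. A sketch under the assumption that
# `params.attr` lists (name, n_categories) pairs and that attributes are
# stored one-hot, n_categories columns per attribute. With flip=True (the
# encoder update above), the targets are replaced by randomly drawn *wrong*
# categories, which is what makes the loss adversarial.
def get_attr_loss(output, attributes, flip, params):
    loss, k = 0, 0
    for _, n_cat in params.attr:
        preds = output[:, k:k + n_cat]                     # logits for this attribute
        y = attributes[:, k:k + n_cat].max(1)[1].view(-1)  # one-hot -> index
        if flip:
            # shift each target by a random non-zero offset modulo n_cat
            shift = torch.randint(1, n_cat, y.shape, device=y.device)
            y = (y + shift) % n_cat
        loss = loss + F.cross_entropy(preds, y)
        k += n_cat
    return loss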
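

# `flip_attributes` is likewise not shown. A simplified sketch, assuming every
# attribute is binary and one-hot encoded as a [1, 0] / [0, 1] pair, so that
# batch_y has shape (batch_size, 2 * params.n_attr); with 'all', every
# attribute is redrawn uniformly at random (the real helper may be richer).
def flip_attributes(attributes, params, mode='all'):
    assert mode == 'all'
    bs = attributes.size(0)
    flipped = attributes.clone()
    for i in range(params.n_attr):
        new_val = torch.randint(0, 2, (bs,), device=attributes.device)
        flipped[:, 2 * i] = (new_val == 0).float()
        flipped[:, 2 * i + 1] = (new_val == 1).float()
    return flipped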