training.py source code

Project: FaderNetworks · Author: facebookresearch

```python
def autoencoder_step(self):
        """
        Train the autoencoder with cross-entropy loss.
        Train the encoder with discriminator loss.
        """
        data = self.data
        params = self.params
        self.ae.train()
        if params.n_lat_dis:
            self.lat_dis.eval()
        if params.n_ptc_dis:
            self.ptc_dis.eval()
        if params.n_clf_dis:
            self.clf_dis.eval()
        bs = params.batch_size
        # batch / encode / decode
        batch_x, batch_y = data.train_batch(bs)
        enc_outputs, dec_outputs = self.ae(batch_x, batch_y)
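        # enc_outputs / dec_outputs are lists of intermediate activations;
        # the last element is the final latent code / reconstruction respectively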
        # autoencoder loss from reconstruction
        loss = params.lambda_ae * ((batch_x - dec_outputs[-1]) ** 2).mean()
        self.stats['rec_costs'].append(loss.data[0])
        # encoder loss from the latent discriminator
        if params.lambda_lat_dis:
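            # adversarial term: flip=True asks for *wrong* attribute predictions,
            # pushing the encoder towards attribute-invariant latent codes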
            lat_dis_preds = self.lat_dis(enc_outputs[-1 - params.n_skip])
            lat_dis_loss = get_attr_loss(lat_dis_preds, batch_y, True, params)
            loss = loss + get_lambda(params.lambda_lat_dis, params) * lat_dis_loss
        # decoding with random labels
        if params.lambda_ptc_dis + params.lambda_clf_dis > 0:
            flipped = flip_attributes(batch_y, params, 'all')
            dec_outputs_flipped = self.ae.decode(enc_outputs, flipped)
        # autoencoder loss from the patch discriminator
        if params.lambda_ptc_dis:
            ptc_dis_preds = self.ptc_dis(dec_outputs_flipped[-1])
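            # smoothed targets: y_fake = smooth_label, so 1 - y_fake is close to 1
            # ("real"); the decoder is trained to fool the patch discriminator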
            y_fake = Variable(torch.FloatTensor(ptc_dis_preds.size())
                              .fill_(params.smooth_label).cuda())
            ptc_dis_loss = F.binary_cross_entropy(ptc_dis_preds, 1 - y_fake)
            loss = loss + get_lambda(params.lambda_ptc_dis, params) * ptc_dis_loss
        # autoencoder loss from the classifier discriminator
        if params.lambda_clf_dis:
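            # non-adversarial term (flip=False): the classifier discriminator should
            # recognize the flipped attributes in the re-decoded image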
            clf_dis_preds = self.clf_dis(dec_outputs_flipped[-1])
            clf_dis_loss = get_attr_loss(clf_dis_preds, flipped, False, params)
            loss = loss + get_lambda(params.lambda_clf_dis, params) * clf_dis_loss
        # check for NaN (NaN != NaN, so this comparison flags a diverged loss)
        if (loss != loss).data.any():
            logger.error("NaN detected")
            exit()
        # optimize
        self.ae_optimizer.zero_grad()
        loss.backward()
        if params.clip_grad_norm:
            clip_grad_norm(self.ae.parameters(), params.clip_grad_norm)
        self.ae_optimizer.step()
```
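The helpers used above (`get_attr_loss`, `get_lambda`, `flip_attributes`) live elsewhere in the FaderNetworks repository and are not shown on this page. As orientation, `get_attr_loss` scores attribute predictions against labels; its boolean third argument selects true labels (`False`) or deliberately wrong ones (`True`, the adversarial case applied to the encoder). A minimal sketch, assuming attributes are concatenated one-hot blocks described by `params.attr` as `(name, n_cat)` pairs:

```python
import torch
import torch.nn.functional as F
from torch.autograd import Variable

def get_attr_loss(output, attributes, flip, params):
    """Cross-entropy over each categorical attribute block (sketch,
    not the repository's exact code)."""
    k = 0
    loss = 0
    for (_, n_cat) in params.attr:
        x = output[:, k:k + n_cat].contiguous()            # predictions for this block
        y = attributes[:, k:k + n_cat].max(1)[1].view(-1)  # one-hot -> class index
        if flip:
            # adversarial targets: shift each label to a different category
            shift = torch.LongTensor(y.size()).random_(n_cat - 1) + 1
            y = (y + Variable(shift.cuda())) % n_cat
        loss += F.cross_entropy(x, y)
        k += n_cat
    return loss
```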
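`get_lambda` scales a discriminator coefficient by a warm-up schedule. A plausible sketch, assuming `params.lambda_schedule` holds the warm-up horizon and `params.n_total_iter` the current iteration count (both names are assumptions, not confirmed by this page):

```python
def get_lambda(l, params):
    """Linearly ramp the coefficient l over the first lambda_schedule
    iterations, then keep it constant (sketch)."""
    s = params.lambda_schedule
    if s == 0:
        return l  # no warm-up: full coefficient from the start
    return l * float(min(params.n_total_iter, s)) / s
```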
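`flip_attributes(batch_y, params, 'all')` draws the random target attributes used for the flipped decodings. A rough sketch for the `'all'` mode, under the same one-hot layout assumption (the repository's version may differ, e.g. by avoiding the original value):

```python
import torch
from torch.autograd import Variable

def flip_attributes(attributes, params, mode='all'):
    """Resample every categorical attribute block uniformly at random
    (sketch for mode='all' only)."""
    assert mode == 'all'
    flipped = attributes.data.clone()
    k = 0
    for (_, n_cat) in params.attr:
        new_cat = torch.LongTensor(flipped.size(0), 1).random_(n_cat).cuda()
        flipped[:, k:k + n_cat] = 0                      # clear the old one-hot
        flipped[:, k:k + n_cat].scatter_(1, new_cat, 1)  # set the new category
        k += n_cat
    return Variable(flipped)
```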
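Finally, this excerpt targets a pre-0.4 PyTorch API and assumes the file-level imports of `torch`, `torch.nn.functional as F`, `Variable` from `torch.autograd`, `clip_grad_norm` from `torch.nn.utils`, and a module-level `logger`. On current PyTorch, `Variable` wrappers are unnecessary, `loss.data[0]` becomes `loss.item()`, `(loss != loss).data.any()` becomes `torch.isnan(loss).any()`, and `clip_grad_norm` becomes `torch.nn.utils.clip_grad_norm_` (note the trailing underscore).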