training.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目: FaderNetworks 作者: facebookresearch 项目源码 文件源码
def ptc_dis_step(self):
        """
        Train the patch discriminator for one optimization step.

        Samples a training batch and reconstructs it with the auto-encoder,
        conditioned on the attributes returned by
        ``flip_attributes(batch_y, params, 'all')``. It then updates the patch
        discriminator so that:

        - its predictions on the real batch move toward
          ``1 - params.smooth_label``;
        - its predictions on the reconstructions move toward
          ``params.smooth_label``.

        The discriminator loss is appended to ``self.stats['ptc_dis_costs']``.
        """
        data = self.data
        params = self.params
        # Only the discriminator is trained here; the auto-encoder is switched
        # to eval mode and only used to produce the fake samples.
        self.ae.eval()
        self.ptc_dis.train()
        bs = params.batch_size

        # batch / encode / discriminate
        batch_x, batch_y = data.train_batch(bs)
        flipped = flip_attributes(batch_y, params, 'all')
        # volatile=True asks legacy PyTorch not to build a graph through the
        # auto-encoder. NOTE(review): this flag has no effect on PyTorch >= 0.4.
        _, dec_outputs = self.ae(Variable(batch_x.data, volatile=True), flipped)
        real_preds = self.ptc_dis(batch_x)
        # Re-wrap the raw data of the last decoder output so the discriminator
        # sees it as a fresh input with no autograd history into the AE.
        fake_preds = self.ptc_dis(Variable(dec_outputs[-1].data))

        # Smoothed target filled with params.smooth_label (the "fake" label);
        # the "real" target below is its complement. The target is allocated
        # from real_preds' tensor type and device rather than a hard-coded
        # CPU FloatTensor followed by .cuda(). This avoids an extra
        # host-to-device copy and keeps it on the same device as the
        # predictions.
        y_fake = Variable(real_preds.data.new(real_preds.size())
                          .fill_(params.smooth_label))

        # loss / optimize
        loss = F.binary_cross_entropy(real_preds, 1 - y_fake)
        loss += F.binary_cross_entropy(fake_preds, y_fake)
        # NOTE(review): loss.data[0] is the pre-0.4 PyTorch idiom for a Python
        # scalar; on newer PyTorch this must become loss.item().
        self.stats['ptc_dis_costs'].append(loss.data[0])
        self.ptc_dis_optimizer.zero_grad()
        loss.backward()
        if params.clip_grad_norm:
            clip_grad_norm(self.ptc_dis.parameters(), params.clip_grad_norm)
        self.ptc_dis_optimizer.step()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号