loss.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:paint_transfer_c92 作者: Hiroshiba 项目源码 文件源码
def make_loss(self, target, raw_line, test):
        """Assemble and report the loss terms for every trained model.

        The forwarder is run once on ``target`` together with both the
        original ``raw_line`` batch and a copy of it whose samples have been
        rotated by one position along axis 0, so that each target is also
        paired with another sample's line.

        Args:
            target: Batch the generator output is compared against; its
                first axis is used as the batch size.
            raw_line: Batch of lines that is permuted along axis 0 to build
                the mismatching pairs.
            test: Flag handed through unchanged to ``self.forwarder.forward``.

        Returns:
            A dict with the keys ``'encoder'``, ``'generator'`` and
            ``'mismatch_discriminator'``, each mapping to a dict of named
            loss terms for that model.
        """
        n_samples = target.shape[0]

        # Rotate the batch order by one so each sample receives a different
        # sample's line.
        # NOTE: with a batch of size 1 the rotation is the identity, so the
        # "mismatch" pair is the same as the matching pair.
        rolled_order = numpy.roll(numpy.arange(n_samples, dtype=numpy.int32), shift=1)
        raw_line_mismatch = chainer.functions.permutate(raw_line, indices=rolled_order, axis=0)

        output = self.forwarder.forward(
            input=target,
            raw_line=raw_line,
            raw_line_mismatch=raw_line_mismatch,
            test=test,
        )
        generated, match, mismatch, z = (
            output[key] for key in ('generated', 'match', 'mismatch', 'z')
        )

        # Constant label vectors, created with the discriminator's array module.
        array_module = self.models.mismatch_discriminator.xp
        label_one = array_module.ones(n_samples, dtype=numpy.float32)
        label_zero = array_module.zeros(n_samples, dtype=numpy.float32)

        least_square = utility.chainer.least_square_mean

        # Generator: reconstruction error only.
        mse = chainer.functions.mean_squared_error(generated, target)
        loss_gen = {
            'mse': mse,
        }
        chainer.report(loss_gen, self.models.generator)

        # Mismatch discriminator: matching pairs are scored against the zero
        # label, mismatching pairs against the one label.
        loss_mismatch_discriminator = {
            'match_lsm': least_square(match, label_zero),
            'mismatch_lsm': least_square(mismatch, label_one),
        }
        chainer.report(loss_mismatch_discriminator, self.models.mismatch_discriminator)

        # Encoder: reconstruction error, the matching pairs scored against
        # the opposite label to the discriminator's, and the mean squared
        # value of z as an activity penalty.
        loss_enc = {
            'mse': mse,
            'fake_mismatch_lsm': least_square(match, label_one),
            'activity_regularization': chainer.functions.sum(z ** 2) / z.size,
        }
        chainer.report(loss_enc, self.models.encoder)

        return {
            'encoder': loss_enc,
            'generator': loss_gen,
            'mismatch_discriminator': loss_mismatch_discriminator,
        }
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号