def make_loss(self, target, raw_line, test):
    """Compute and report encoder, generator and mismatch-discriminator losses.

    ``raw_line`` is rotated by one position along the batch axis to build
    deliberately mismatched pairs. The forwarder is then run once. Its outputs
    are turned into three loss dicts. Each dict is reported via
    ``chainer.report`` with the corresponding model as observer and also
    returned.

    Args:
        target: Batch the generator output is compared against with MSE.
            Its first dimension is used as the batch size.
        raw_line: Batch passed to the forwarder alongside ``target``. It is
            also permuted along axis 0 to form the mismatched input.
            Presumably it shares ``target``'s batch size and the models'
            device; TODO confirm.
        test: Forwarded unchanged to ``self.forwarder.forward``.

    Returns:
        dict with keys ``'encoder'``, ``'generator'`` and
        ``'mismatch_discriminator'``, each mapping to that model's loss dict.
    """
    xp = self.models.mismatch_discriminator.xp
    batchsize = target.shape[0]
    l_true = xp.ones(batchsize, dtype=numpy.float32)
    l_false = xp.zeros(batchsize, dtype=numpy.float32)

    # Rotate the batch by one so sample i is paired with raw_line[i - 1].
    # The indices are built with ``xp`` rather than numpy, so they live on the
    # same device as the other arrays. Chainer refuses to mix numpy and cupy
    # arrays in the inputs of a single function call on GPU.
    # NOTE(review): with batchsize == 1 the rotation is the identity, so the
    # "mismatched" input is identical to the matched one.
    mismatch_indices = xp.roll(xp.arange(batchsize, dtype=numpy.int32), shift=1)
    raw_line_mismatch = chainer.functions.permutate(
        raw_line, indices=mismatch_indices, axis=0)

    output = self.forwarder.forward(
        input=target,
        raw_line=raw_line,
        raw_line_mismatch=raw_line_mismatch,
        test=test,
    )
    generated = output['generated']
    match = output['match']
    mismatch = output['mismatch']
    z = output['z']

    # Generator: reconstruction error against the target.
    mse = chainer.functions.mean_squared_error(generated, target)
    loss_gen = {'mse': mse}
    chainer.report(loss_gen, self.models.generator)

    # Mismatch discriminator: the matched-pair output is scored against label
    # 0 and the mismatched-pair output against label 1.
    match_lsm = utility.chainer.least_square_mean(match, l_false)
    mismatch_lsm = utility.chainer.least_square_mean(mismatch, l_true)
    loss_mismatch_discriminator = {'match_lsm': match_lsm, 'mismatch_lsm': mismatch_lsm}
    chainer.report(loss_mismatch_discriminator, self.models.mismatch_discriminator)

    # Encoder: adversarial term scores the matched-pair output against the
    # opposite label (1). It also gets an activity regularizer equal to the
    # mean of the squared entries of z.
    fake_mismatch_lsm = utility.chainer.least_square_mean(match, l_true)
    z_l2 = chainer.functions.sum(z ** 2) / z.size
    loss_enc = {'mse': mse, 'fake_mismatch_lsm': fake_mismatch_lsm, 'activity_regularization': z_l2}
    chainer.report(loss_enc, self.models.encoder)

    return {
        'encoder': loss_enc,
        'generator': loss_gen,
        'mismatch_discriminator': loss_mismatch_discriminator,
    }
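# NOTE(review): two stray web-page navigation labels ("comment list" and
# "table of contents") from the page this code was copied from stood here;
# they are not part of the program.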