def update(gen, dis, optimizer_gen, optimizer_dis, x_batch, margin):
    """Run one training step for the generator and then the discriminator.

    The discriminator is used as an energy function: it is called on a batch
    and returns a pair ``(y, h)``, where ``y`` is compared against the input
    with a reconstruction error and ``h`` is a hidden representation used for
    a similarity penalty on generated samples.

    Besides its arguments, this function reads the module-level name
    ``latent_size`` for the dimensionality of the latent vectors.

    Args:
        gen: Generator model. Must expose ``xp`` and ``cleargrads()`` and be
            callable on a ``(batch_size, latent_size)`` float32 array.
        dis: Discriminator model. Must expose ``cleargrads()`` and return a
            ``(reconstruction, hidden)`` pair when called on a batch.
        optimizer_gen: Optimizer whose ``update()`` is called after the
            generator loss has been backpropagated (presumably set up on
            ``gen`` -- confirm at the call site).
        optimizer_dis: Optimizer whose ``update()`` is called after the
            discriminator loss has been backpropagated (presumably set up on
            ``dis`` -- confirm at the call site).
        x_batch: Batch of real samples; converted with ``xp.asarray``.
        margin: Hinge margin. It is squared and compared against the
            per-sample squared L2 reconstruction error of generated samples.

    Returns:
        A ``(loss_gen, loss_dis)`` tuple of Python floats.
    """
    xp = gen.xp
    batch_size = len(x_batch)

    # from generated image
    # Latent vectors are drawn from N(0, 1) and rescaled to unit L2 norm;
    # the small epsilon guards against division by zero.
    z = xp.random.normal(0, 1, (batch_size, latent_size)).astype(np.float32)
    z = z / (xp.linalg.norm(z, axis=1, keepdims=True) + 1e-12)
    x_gen = gen(z)
    total_size = np.prod(x_gen.shape)
    y_gen, h_gen = dis(x_gen)

    # Mean pairwise dot product of the row-normalized hidden features
    # (diagonal included); it penalizes generated samples whose hidden
    # representations all point in the same direction.
    h_gen = F.normalize(F.reshape(h_gen, (batch_size, -1)))
    similarity = F.sum(F.matmul(h_gen, h_gen, transb=True)) / (batch_size * batch_size)

    # Generator: make generated samples easy to reconstruct, plus the
    # similarity penalty weighted by 0.1.
    loss_gen = F.mean_squared_error(x_gen, y_gen) + 0.1 * similarity
    # Discriminator, generated part: hinge on the per-sample squared error,
    # divided by the total element count so it is on the same per-element
    # scale as the mean squared error used for the real part.
    loss_dis = F.sum(F.relu(margin * margin - F.batch_l2_norm_squared(x_gen - y_gen))) / total_size

    # from real image
    x = xp.asarray(x_batch)
    y, h = dis(x)
    loss_dis += F.mean_squared_error(x, y)

    gen.cleargrads()
    loss_gen.backward()
    optimizer_gen.update()

    # The generator step is finished. Detach x_gen from the generator's graph
    # so the discriminator backward pass stops at x_gen instead of needlessly
    # propagating through (and writing stale gradients into) the generator.
    x_gen.unchain_backward()

    # loss_gen.backward() also accumulated gradients in dis; clear them
    # before computing the discriminator's own gradients.
    dis.cleargrads()
    loss_dis.backward()
    optimizer_dis.update()

    return float(loss_gen.data), float(loss_dis.data)
# Comment list (web page navigation residue)
# Table of contents (web page navigation residue)