def vae_loss(x, x_hat):
    """Return the VAE training loss: scaled reconstruction loss plus KL term.

    This function reads several names from the enclosing scope that are not
    visible here. Assumptions to confirm against the surrounding code:

    * ``K`` -- presumably the Keras backend module.
    * ``objectives`` -- presumably the (old) ``keras.objectives`` module.
    * ``z_mean`` / ``z_log_var`` -- mean and log-variance of the approximate
      posterior; the KL formula below treats them as a diagonal Gaussian.
    * ``n`` -- scale factor for the reconstruction term, presumably the input
      dimensionality (the Keras losses average over the last axis). TODO
      confirm.
    * ``use_loss`` -- selects the reconstruction loss: ``'xent'`` or ``'mse'``.

    Args:
        x: target tensor.
        x_hat: reconstructed tensor.

    Returns:
        ``n * reconstruction_loss(x, x_hat) + kl_loss``, where the
        reconstruction loss is binary cross-entropy (``'xent'``) or mean
        squared error (``'mse'``).

    Raises:
        ValueError: if ``use_loss`` is neither ``'xent'`` nor ``'mse'``.
    """
    # Resolve the reconstruction loss first. This way only the selected loss
    # is built, and an invalid setting fails before any tensor ops are created.
    if use_loss == 'xent':
        recon_fn = objectives.binary_crossentropy
    elif use_loss == 'mse':
        recon_fn = objectives.mse
    else:
        # ValueError (not Python 2 ``raise E, msg`` syntax) so this works on
        # Python 3 and tells the caller which value was rejected.
        raise ValueError('Unknown loss: %r' % (use_loss,))

    recon_loss = n * recon_fn(x, x_hat)

    # KL divergence between N(z_mean, exp(z_log_var)) and a standard normal,
    # summed over the last axis.
    kl_loss = -0.5 * K.sum(
        1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)

    return recon_loss + kl_loss
# 评论列表 (web-page residue: "comment list")
# 文章目录 (web-page residue: "table of contents")