def loss(x64, x_tilde, z_x_log_sigma_sq1, z_x_meanx1, d_x, d_x_p, l_x, l_x_tilde, ss_,
         norm=64 * 64 * 3.0, eps=1e-5):
    """Build the VAE-GAN losses together with a pairwise hashing loss.

    Args:
        x64: Input images, compared element-wise with ``x_tilde``.
            NOTE(review): presumably 64x64x3 images, matching the default
            ``norm`` -- confirm against the caller.
        x_tilde: Reconstructions of ``x64``.
        z_x_log_sigma_sq1: Encoder log-variance; the KL term sums it over axis 1.
        z_x_meanx1: Encoder mean. Also used (unclipped) as relaxed codes in
            ``pair_loss``, where ``matmul(z, z^T)`` requires it to be rank 2.
        d_x: Discriminator output on real inputs, treated as a probability.
        d_x_p: Discriminator output on generated inputs, treated as a probability.
        l_x: Feature activations for the real input.
            NOTE(review): presumably an intermediate discriminator layer -- confirm.
        l_x_tilde: The same feature activations for the reconstruction.
        ss_: Target similarity matrix; must broadcast against
            ``matmul(z_x_meanx1, transpose(z_x_meanx1))`` (batch x batch).
        norm: Divisor applied to the summed KL and feature-matching terms.
            Defaults to 64*64*3, the constant hard-coded in the original.
        eps: Lower clip bound applied to every probability before ``log``.

    Returns:
        Tuple ``(SSE_loss, KL_loss, D_loss, G_loss, LL_loss, pair_loss)``.
        Despite its name, ``SSE_loss`` is a *mean* squared error.
    """
    # Clip once so extreme encoder outputs cannot overflow exp()/square().
    log_sigma_sq = tf.clip_by_value(z_x_log_sigma_sq1, -10.0, 10.0)
    mean = tf.clip_by_value(z_x_meanx1, -10.0, 10.0)

    def safe_log(p):
        # Keep probabilities away from 0 so log() never yields -inf.
        # NOTE: tf.log is the TF 1.x name; TF 2.x spells it tf.math.log.
        return tf.log(tf.clip_by_value(p, eps, 1.0))

    SSE_loss = tf.reduce_mean(tf.square(x64 - x_tilde))

    # First term: inner products of the codes should match ss_.
    # Second term: pushes every code entry towards its sign (+1 / -1).
    similarity = tf.matmul(z_x_meanx1, tf.transpose(z_x_meanx1))
    pair_loss = (tf.reduce_mean(tf.square(similarity - ss_))
                 + tf.reduce_mean(tf.square(z_x_meanx1 - tf.sign(z_x_meanx1))))

    # Closed-form KL(N(mean, exp(log_sigma_sq)) || N(0, 1)) per row (sum over
    # axis 1), summed over the batch and then divided by `norm`.
    kl_per_row = -0.5 * tf.reduce_sum(
        1 + log_sigma_sq - tf.square(mean) - tf.exp(log_sigma_sq), 1)
    KL_loss = tf.reduce_sum(kl_per_row) / norm

    # Standard GAN discriminator / generator cross-entropy terms.
    D_loss = tf.reduce_mean(-1.0 * (safe_log(d_x) + safe_log(1.0 - d_x_p)))
    G_loss = tf.reduce_mean(-1.0 * safe_log(d_x_p))

    # Feature-matching loss, normalized like the KL term.
    LL_loss = tf.reduce_sum(tf.square(l_x - l_x_tilde)) / norm

    return SSE_loss, KL_loss, D_loss, G_loss, LL_loss, pair_loss
评论列表
文章目录