def tf_bivariate_normal(y, mu, sigma, rho, n_mixtures, batch_size):
    """Evaluate per-mixture-component bivariate normal densities at ``y``.

    For every mixture component this computes

        N(y | mu, sigma, rho) = exp(-Z / (2 (1 - rho^2)))
                                / (2 pi sigma_1 sigma_2 sqrt(1 - rho^2))

    where, with d = y - mu,

        Z = (d_1 / sigma_1)^2 + (d_2 / sigma_2)^2
            - 2 rho d_1 d_2 / (sigma_1 sigma_2).

    Every input and intermediate is wrapped in
    ``tf.verify_tensor_all_finite``. A NaN/Inf therefore fails at graph
    execution, with a message naming the offending tensor.

    Args:
        y: Target points. Must be rank 2, because it is expanded to rank 3
            and tiled along axis 1. Presumably (batch, 2) -- TODO confirm.
        mu: Component means, broadcast-compatible with the tiled ``y``.
            Presumably (batch, n_mixtures, 2) -- TODO confirm.
        sigma: Component standard deviations, same layout as ``mu``.
            Assumed positive; this is not checked here.
        rho: Component correlations, expected in [-1, 1]. Must broadcast
            against ``delta`` reduced over axis 2; presumably
            (batch, n_mixtures) -- TODO confirm.
        n_mixtures: Number of mixture components (tiling factor for ``y``).
        batch_size: Unused. Kept for interface compatibility.

    Returns:
        A tuple ``(density, delta)``:

        - ``density`` is the per-component density (axis 2 reduced away),
          clipped to ``[epsilon, 1.0]``.
        - ``delta`` is the tiled ``y`` minus ``mu``.

    Relies on a module-level ``epsilon`` constant for its numerical guards.
    """
    mu = tf.verify_tensor_all_finite(mu, "Mu not finite!")
    y = tf.verify_tensor_all_finite(y, "Y not finite!")
    sigma = tf.verify_tensor_all_finite(sigma, "Sigma not finite!")
    # Check rho *before* it is used to build z, like every other input.
    rho = tf.verify_tensor_all_finite(rho, "rho in bivariate normal not finite!")

    # d = y - mu, with y repeated once per mixture component along axis 1.
    y_tiled = tf.tile(tf.expand_dims(y, 1), [1, n_mixtures, 1])
    delta = tf.verify_tensor_all_finite(tf.sub(y_tiled, mu), "Delta not finite!")

    # sigma_1 * sigma_2, the product over the last axis.
    s = tf.verify_tensor_all_finite(tf.reduce_prod(sigma, 2), "S not finite!")

    # Z. Here epsilon only guards the divisions. The old stray "+ epsilon"
    # inside the square biased every squared term and protected nothing.
    normalized = tf.div(delta, sigma + epsilon)
    cross_term = 2 * tf.div(tf.mul(rho, tf.reduce_prod(delta, 2)), s + epsilon)
    z = tf.reduce_sum(tf.square(normalized), 2) - cross_term
    z = tf.verify_tensor_all_finite(z, "Z not finite!")

    # 1 - rho^2, kept in [epsilon, 1] so the division and sqrt below stay finite.
    one_minus_rho_sq = tf.clip_by_value(1 - tf.square(rho), epsilon, 1.0)
    one_minus_rho_sq = tf.verify_tensor_all_finite(one_minus_rho_sq, "negRho not finite!")

    # If 1 - rho^2 gets near zero, or z gets very large, the exponent
    # explodes. Clip the exponential to keep it representable.
    numer = tf.clip_by_value(tf.exp(tf.div(-z, 2 * one_minus_rho_sq)), 1.0e-8, 1.0e8)
    numer = tf.verify_tensor_all_finite(numer, "Result in bivariate normal not finite!")

    denom = 2 * np.pi * tf.mul(s, tf.sqrt(one_minus_rho_sq))
    denom = tf.verify_tensor_all_finite(denom, "Denom in bivariate normal not finite!")

    # NOTE(review): a true density can exceed 1 when the sigmas are small.
    # The upper clip at 1.0 flattens it and zeroes the gradient there. It is
    # kept unchanged because downstream code may depend on it -- confirm this
    # is intended.
    result = tf.clip_by_value(tf.div(numer, denom + epsilon), epsilon, 1.0)
    result = tf.verify_tensor_all_finite(result, "Result2 in bivariate normal not finite!")
    return result, delta
评论列表
文章目录