def normal_sample(mean, var, full_cov=False):
    """Draw one sample from a Gaussian by transforming standard-normal noise.

    Args:
        mean: Mean tensor. In the full-covariance case it is unpacked as
            ``S, N, D`` via ``shape_as_list``.
        var: If ``full_cov`` is falsy, variances combined elementwise with
            ``mean`` (presumably the same shape as, or broadcastable to,
            ``mean`` -- TODO confirm against callers). If ``full_cov`` is
            truthy, covariances laid out as ``SNND``.
        full_cov: Truthy to treat ``var`` as full ``N x N`` covariances per
            sample and output column; falsy to treat it as marginal
            variances.

    Returns:
        In the diagonal case, ``mean + z * sqrt(var)`` with ``z`` drawn at the
        shape of ``mean``. In the full-covariance case, a tensor laid out as
        ``SND``.
    """
    # A truthiness test (not ``is False``) so falsy values such as 0, None or
    # numpy.bool_(False) select the diagonal branch instead of falling into
    # the full-covariance branch with a wrongly shaped ``var``.
    if not full_cov:
        z = tf.random_normal(tf.shape(mean), dtype=float_type)
        return mean + z * var ** 0.5

    S, N, D = shape_as_list(mean)  # var is SNND
    mean = tf.transpose(mean, (0, 2, 1))  # SND -> SDN
    var = tf.transpose(var, (0, 3, 1, 2))  # SNND -> SDNN
    # No jitter is added before factorising: the original author notes that
    # ``var`` already carries jitter. If the Cholesky fails, add a small
    # multiple of tf.eye(N, dtype=float_type)[None, None, :, :] here.
    chol = tf.cholesky(var)  # SDNN
    z = tf.random_normal([S, D, N, 1], dtype=float_type)
    # chol @ z gives SDN1; drop the trailing singleton axis to get SDN.
    f = mean + tf.matmul(chol, z)[:, :, :, 0]
    return tf.transpose(f, (0, 2, 1))  # SDN -> SND
# Web-page navigation residue from the original article: "评论列表" (Comment list), "文章目录" (Table of contents).