utils.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:Doubly-Stochastic-DGP 作者: ICL-SML 项目源码 文件源码
def normal_sample(mean, var, full_cov=False):
    if full_cov is False:
        z = tf.random_normal(tf.shape(mean), dtype=float_type)
        return mean + z * var ** 0.5
    else:
        S, N, D = shape_as_list(mean) # var is SNND
        mean = tf.transpose(mean, (0, 2, 1))  # SND -> SDN
        var = tf.transpose(var, (0, 3, 1, 2))  # SNND -> SDNN
#        I = jitter * tf.eye(N, dtype=float_type)[None, None, :, :] # 11NN
        chol = tf.cholesky(var)# + I)  # SDNN should be ok without as var already has jitter
        z = tf.random_normal([S, D, N, 1], dtype=float_type)
        f = mean + tf.matmul(chol, z)[:, :, :, 0]  # SDN(1)
        return tf.transpose(f, (0, 2, 1)) # SND
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号