def multisample_conditional(self, X, full_cov=False):
    """Evaluate ``self.conditional`` on a stack of input samples.

    Args:
        X: Tensor whose shape unpacks into exactly three dimensions
            ``[S, N, D]`` (presumably S samples of N points with D
            features -- TODO confirm against callers).
        full_cov: When truthy, ``self.conditional`` is called with
            ``full_cov=True`` once per slice along the first axis of ``X``.
            Otherwise it is called a single time, with its default
            arguments, on ``X`` flattened to ``[S*N, D]``.

    Returns:
        A ``mean, var`` pair. In the full-covariance branch it is a tuple
        of the per-slice results stacked along the first axis; otherwise it
        is a list of two tensors, each reshaped to ``[S, N, -1]``.
    """
    # Branch on truthiness rather than ``is True``: an identity test sends
    # truthy non-singleton flags (e.g. ``1`` or ``numpy.bool_(True)``) down
    # the marginal path, silently discarding the caller's request.
    if full_cov:
        # This is unlikely to be called in a performance-critical
        # application, so we use this clear but slow implementation.
        def conditional_full_cov(x_slice):
            return self.conditional(x_slice, full_cov=True)

        # map_fn needs the output structure declared because it differs from
        # the input; both outputs are hard-coded to float64 here.
        mean, var = tf.map_fn(
            conditional_full_cov, X, dtype=(tf.float64, tf.float64)
        )
        # NOTE(review): map_fn already returns stacked tensors, so these
        # tf.stack calls look redundant -- confirm before removing.
        return tf.stack(mean), tf.stack(var)

    # A single conditional call on the flattened input, which the original
    # author notes computes Z_uu only once. Could perhaps be made faster still
    # by avoiding the reshapes, but that needs conditional to be rewritten.
    S, N, D = shape_as_list(X)
    X_flat = tf.reshape(X, [S * N, D])
    mean, var = self.conditional(X_flat)
    # NOTE(review): this branch returns a list while the branch above returns
    # a tuple; kept as-is so existing callers see the same container type.
    return [tf.reshape(m, [S, N, -1]) for m in [mean, var]]
评论列表
文章目录