def _sample(self, *params, **kwargs):
    """Return the mean of the most probable Gaussian component.

    For every batch element, the mixture component with the largest mixing
    weight is selected and that component's mean is returned.

    Args:
        *params: Exactly three tensors ``(mixings, sigma, mean)``.
            ``mixings`` is reduced with ``argmax`` over axis 1, so it is
            presumably ``(batch, n_components)`` -- confirm against caller.
            ``mean`` must carry the component axis at position 1, i.e.
            ``(batch, n_components)`` or ``(batch, n_components, ...)``.
            ``sigma`` is not used by this method.
        **kwargs: Accepted for interface compatibility; ignored.

    Returns:
        The selected means. A 2-D ``mean`` yields shape ``(batch, 1)``;
        a higher-rank ``mean`` yields ``(batch, ...)`` with its trailing
        dimensions preserved.
    """
    mixings, _sigma, mean = params
    # Use the dynamic batch size: the static shape (get_shape()[0]) can be
    # unknown at graph-construction time, which made tf.range() fail.
    batch_size = tf.shape(mixings)[0]
    # Index of the most probable component for each batch element.
    id_mix = tf.cast(tf.argmax(mixings, axis=1), tf.int32)
    # Pair every batch row with its chosen component: [[0, k0], [1, k1], ...].
    # gather_nd then selects mean[b, k_b, ...]. This matches the old
    # flatten-and-offset lookup for 2-D means and is also correct for
    # higher-rank means, where the flat offset picked the wrong elements.
    indices = tf.stack([tf.range(batch_size), id_mix], axis=1)
    sample = tf.gather_nd(mean, indices)
    # Keep a trailing feature axis so a 2-D ``mean`` yields (batch, 1).
    # An unknown static rank (None) would break a plain `< 2` comparison.
    ndims = sample.get_shape().ndims
    if ndims is None or ndims < 2:
        sample = tf.expand_dims(sample, axis=1)
    return sample
评论列表
文章目录