def _mean(self):
    """Return the probability-weighted sum of the component means.

    Each component's ``mean()`` is multiplied by the matching entry of
    ``self._cat_probs(log_probs=False)`` and the products are summed with
    ``math_ops.add_n``. All ops are created under a control dependency on
    ``self._assertions``.
    """
    with ops.control_dependencies(self._assertions):
        component_means = [component.mean() for component in self.components]
        mixing_probs = self._cat_probs(log_probs=False)
        # This was checked to not be None at construction time.
        event_rank = self.get_event_shape().ndims

        # Append one trailing unit axis per event dimension to every
        # probability so that it broadcasts against the component means.
        broadcastable_probs = []
        for prob in mixing_probs:
            for _ in range(event_rank):
                prob = array_ops.expand_dims(prob, -1)
            broadcastable_probs.append(prob)

        weighted_means = [
            prob * mean
            for prob, mean in zip(broadcastable_probs, component_means)
        ]
        # The products are expected to share one shape, since components
        # have matching batch_shape and event_shape.
        return math_ops.add_n(weighted_means)
评论列表
文章目录