mixture.py source code

python

Project: lsdc  Author: febert
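from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops


# Method of the Mixture distribution class; self.components, self._assertions
# and self._cat_probs are defined elsewhere in mixture.py.
def _mean(self):
  """Mean of the mixture: sum over components k of cat_prob_k * mean_k."""
  with ops.control_dependencies(self._assertions):
    distribution_means = [d.mean() for d in self.components]
    # List of num_components tensors, each of shape batch_shape.
    cat_probs = self._cat_probs(log_probs=False)
    # The static event rank was verified to be known (not None) at
    # construction time.
    static_event_rank = self.get_event_shape().ndims

    # Append static_event_rank trailing singleton dimensions to x so that a
    # batch-shaped mixing weight broadcasts against batch_shape + event_shape.
    def expand(x):
      expanded_x = x
      for _ in range(static_event_rank):
        expanded_x = array_ops.expand_dims(expanded_x, -1)
      return expanded_x

    cat_probs = [expand(c_p) for c_p in cat_probs]
    partial_means = [
        c_p * m for (c_p, m) in zip(cat_probs, distribution_means)
    ]
    # All partial means have the same shape because every component shares
    # the same batch_shape and event_shape.
    return math_ops.add_n(partial_means)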
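The method computes the mixture mean E[X] = sum_k p_k * mu_k, where p_k are the categorical mixing weights and mu_k are the component means. Below is a minimal NumPy sketch of the same computation, assuming the mixing weights are passed directly as one array of shape batch_shape + [K] (the TensorFlow code instead obtains them from `self._cat_probs(log_probs=False)`) and the component means as a list of K arrays of shape batch_shape + event_shape. `mixture_mean` and the module-level `expand` are hypothetical helpers for illustration, not TensorFlow APIs; `np.stack` followed by `np.sum` stands in for `math_ops.add_n`.

```python
import numpy as np


def expand(x, rank):
    """Append `rank` trailing singleton axes to x (mirrors the inner expand())."""
    for _ in range(rank):
        x = np.expand_dims(x, -1)
    return x


def mixture_mean(cat_probs, component_means):
    """Mean of a finite mixture: sum_k cat_probs[..., k] * component_means[k].

    cat_probs: mixing weights, shape batch_shape + [K]; last axis sums to 1.
    component_means: list of K arrays, each of shape batch_shape + event_shape.
    """
    cat_probs = np.asarray(cat_probs, dtype=float)
    component_means = [np.asarray(m, dtype=float) for m in component_means]
    # event rank = rank(mean) - batch rank, where batch rank = rank(cat_probs) - 1.
    event_rank = component_means[0].ndim - (cat_probs.ndim - 1)
    partial_means = []
    for k, mean_k in enumerate(component_means):
        # Weight of component k has shape batch_shape; pad it for broadcasting.
        p_k = expand(cat_probs[..., k], event_rank)
        partial_means.append(p_k * mean_k)
    # Elementwise sum of equally shaped arrays (the role of math_ops.add_n).
    return np.sum(np.stack(partial_means), axis=0)


# Batch of two mixtures, K = 2 components, event_shape = [3].
probs = [[0.25, 0.75], [0.5, 0.5]]
means = [
    [[0.0, 0.0, 0.0], [2.0, 2.0, 2.0]],  # component 0 means, shape (2, 3)
    [[4.0, 4.0, 4.0], [0.0, 2.0, 4.0]],  # component 1 means, shape (2, 3)
]
print(mixture_mean(probs, means))  # rows [3, 3, 3] and [1, 2, 3]
```

The padding goes on the right because NumPy and TensorFlow broadcasting align shapes from the trailing axis. Without it, a weight of shape (2,) multiplied by means of shape (2, 3) raises a shape error, and when the batch size happens to equal the event size the weights would silently broadcast along the wrong axis.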