gmm_ops.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def _create_variables(self, data, initial_means=None):
    """Initializes the variables used by the GMM algorithm.

    Creates `self._means`, `self._covs` and `self._alpha`.

    Args:
      data: a list of Tensors with data, each row is a new example.
      initial_means: a Tensor with a matrix of means. When None, the means
        are taken from `_init_clusters_random` instead.

    Raises:
      ValueError: if `self._covariance_type` is neither FULL_COVARIANCE nor
        DIAG_COVARIANCE.
    """
    # Only the first shard is used to seed the covariance estimate below.
    first_shard = data[0]
    # Initialize means: num_classes X 1 X dimensions.
    if initial_means is not None:
        self._means = tf.Variable(tf.expand_dims(initial_means, 1),
                                  name=self.CLUSTERS_VARIABLE,
                                  validate_shape=False, dtype=tf.float32)
    else:
        # Sample data randomly
        self._means = tf.Variable(tf.expand_dims(
            _init_clusters_random(data, self._num_classes, self._random_seed),
            1),
                                  name=self.CLUSTERS_VARIABLE,
                                  validate_shape=False)

    # Initialize covariances. Every class starts from the same estimate,
    # offset by self._min_var.
    if self._covariance_type == FULL_COVARIANCE:
        cov = _covariance(first_shard, False) + self._min_var
        # A matrix per class, num_classes X dimensions X dimensions
        covs = tf.tile(
            tf.expand_dims(cov, 0), [self._num_classes, 1, 1])
    elif self._covariance_type == DIAG_COVARIANCE:
        cov = _covariance(first_shard, True) + self._min_var
        # A diagonal per row, num_classes X dimensions.
        covs = tf.tile(tf.expand_dims(tf.diag_part(cov), 0),
                       [self._num_classes, 1])
    else:
        # Without this branch `covs` would be unbound below and the failure
        # would surface as an UnboundLocalError that hides the real cause.
        raise ValueError(
            'Unsupported covariance type: %s' % (self._covariance_type,))
    self._covs = tf.Variable(covs, name='clusters_covs', validate_shape=False)
    # Mixture weights, representing the probability that a randomly
    # selected unobservable data (in EM terms) was generated by component k.
    # Start uniform: each component gets weight 1 / num_classes.
    self._alpha = tf.Variable(tf.tile([1.0 / self._num_classes],
                                      [self._num_classes]))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号