gmm_ops.py source code

python

Project: lsdc  Author: febert
def _define_partial_maximization_operation(self, shard_id, shard):
    """Computes the partial statistics of the means and covariances.

    Args:
      shard_id: current shard id.
      shard: current data shard, 1 X num_examples X dimensions.
    """
    # Sum the soft assignments over the shard: effective point count per cluster.
    self._points_in_k[shard_id] = tf.reduce_sum(self._w[shard_id], 0,
                                                keep_dims=True)
    # Partial means.
    w_mul_x = tf.expand_dims(
        tf.matmul(self._w[shard_id],
                  tf.squeeze(shard, [0]), transpose_a=True), 1)
    self._w_mul_x.append(w_mul_x)
    # Partial covariances.
    x = tf.concat(0, [shard for _ in range(self._num_classes)])
    x_trans = tf.transpose(x, perm=[0, 2, 1])
    x_mul_w = tf.concat(0, [
        tf.expand_dims(x_trans[k, :, :] * self._w[shard_id][:, k], 0)
        for k in range(self._num_classes)])
    self._w_mul_x2.append(tf.batch_matmul(x_mul_w, x))
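
To make the tensor arithmetic above easier to follow, here is a minimal, dependency-free sketch of the same three partial statistics for a single shard. The function name `partial_statistics` and its list-based inputs are assumptions made for illustration and are not part of `gmm_ops.py`. `weights[n][k]` plays the role of `self._w[shard_id]` (the soft assignment of point `n` to cluster `k`), and `points` plays the role of the squeezed shard.

```python
def partial_statistics(points, weights):
    """Per-shard sufficient statistics for a GMM M-step (illustrative only).

    points:  N rows of D floats (one shard of data).
    weights: N rows of K floats (soft assignment of each point to each cluster).
    """
    num_points = len(points)
    dims = len(points[0])
    num_classes = len(weights[0])

    # Effective number of points per cluster: sum_n w[n][k].
    points_in_k = [
        sum(weights[n][k] for n in range(num_points))
        for k in range(num_classes)
    ]

    # Weighted sums, shape K x D: sum_n w[n][k] * x[n].
    w_mul_x = [
        [sum(weights[n][k] * points[n][d] for n in range(num_points))
         for d in range(dims)]
        for k in range(num_classes)
    ]

    # Weighted second moments, shape K x D x D: sum_n w[n][k] * x[n] x[n]^T.
    w_mul_x2 = [
        [[sum(weights[n][k] * points[n][i] * points[n][j]
              for n in range(num_points))
          for j in range(dims)]
         for i in range(dims)]
        for k in range(num_classes)
    ]
    return points_in_k, w_mul_x, w_mul_x2


counts, sums, moments = partial_statistics(
    points=[[1.0, 2.0], [3.0, 4.0]],
    weights=[[1.0, 0.0], [0.5, 0.5]],
)
print(counts)   # effective counts per cluster
print(sums)     # weighted sums per cluster
```

Under the usual EM update, these per-shard values would be summed across shards, after which the mean of cluster `k` is its weighted sum divided by its effective count. How the surrounding class combines them is not shown in this snippet, so treat that step as an assumption.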