cluster.py source file

python

Project: xlearn  Author: wy2136
def predict(self, da):
        '''xarray.DataArray version of sklearn.cluster.KMeans.predict.'''

        # fall back to sklearn.cluster.KMeans.predict when the input is not a DataArray
        if not isinstance(da, xr.DataArray):
            return super().predict(da)

        # retrieve parameters
        n_samples = da.shape[0]
        features_shape = da.shape[1:]
        n_features = np.prod(features_shape)

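        X = da.data.reshape(n_samples, n_features)  # 'data' might be replaced with 'values'.
        # drop NaN features, if a valid-feature index is stored on the estimator
        try:
            X_valid = X[:, self.valid_features_index_]
        except AttributeError:
            # no valid-feature index available: use all features
            X_valid = X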
        samples_dim = da.dims[0]
        samples_coord = {samples_dim: da.coords[samples_dim]}
        labels = xr.DataArray(super().predict(X_valid),
            dims=samples_dim, coords=samples_coord)

        return labels
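
The flatten-and-mask step in the method above can be tried in isolation. The following is a minimal NumPy-only sketch, not xlearn's actual code: the helper names `flatten_features` and `valid_feature_mask` are hypothetical. It also assumes that `valid_features_index_` marks the feature columns that are free of NaN, which the snippet itself does not show.

```python
import numpy as np


def flatten_features(data):
    """Reshape (n_samples, *feature_dims) into (n_samples, n_features)."""
    n_samples = data.shape[0]
    n_features = int(np.prod(data.shape[1:]))
    return data.reshape(n_samples, n_features)


def valid_feature_mask(X):
    """Boolean mask of the columns that contain no NaN in any sample."""
    return ~np.isnan(X).any(axis=0)


# Hypothetical data: 4 samples on a 2x3 grid, with one grid point missing.
data = np.arange(24, dtype=float).reshape(4, 2, 3)
data[:, 0, 1] = np.nan

X = flatten_features(data)     # 2-D array: one row per sample
mask = valid_feature_mask(X)   # plays the role assumed for valid_features_index_
X_valid = X[:, mask]           # NaN-free matrix that KMeans can accept
```

A mask computed on the training data has to be reused unchanged at prediction time, so that the columns passed to `predict` line up with those seen by `fit`.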