def predict(self, da):
    '''xarray.DataArray version of sklearn.cluster.KMeans.predict.

    Parameters
    ----------
    da : xarray.DataArray or array-like
        Data to assign to clusters. For a DataArray, the first dimension is
        treated as the samples dimension and all remaining dimensions are
        flattened into a single feature axis. Any other input is forwarded
        unchanged to the parent class's ``predict``.

    Returns
    -------
    labels : xarray.DataArray, or whatever the parent ``predict`` returns
        For DataArray input: a 1-D DataArray of cluster labels along the
        samples dimension, carrying that dimension's coordinate.
    '''
    # Non-DataArray input: stay compatible with sklearn.cluster.KMeans.predict.
    if not isinstance(da, xr.DataArray):
        return super().predict(da)

    # Flatten to the 2-D (n_samples, n_features) layout sklearn expects.
    n_samples = da.shape[0]
    features_shape = da.shape[1:]
    # int(...) matters for 1-D input: np.prod(()) is the float 1.0, which
    # reshape() rejects as a dimension size.
    n_features = int(np.prod(features_shape))
    # NOTE(review): `.data` keeps the backing array type as-is (e.g. dask);
    # `.values` would force a NumPy array instead -- the original author was
    # considering that switch.
    X = da.data.reshape(n_samples, n_features)

    # Keep only the feature columns recorded during fitting, if any.
    # Presumably fit() sets `valid_features_index_` to the NaN-free feature
    # columns -- verify against fit(). Only a missing attribute means "use
    # every feature"; other errors (e.g. a shape mismatch) now propagate
    # instead of being silently swallowed by a bare except.
    try:
        valid_index = self.valid_features_index_
    except AttributeError:
        X_valid = X
    else:
        X_valid = X[:, valid_index]

    # Wrap the labels back into a DataArray along the samples dimension.
    samples_dim = da.dims[0]
    samples_coord = {samples_dim: da.coords[samples_dim]}
    labels = xr.DataArray(super().predict(X_valid),
                          dims=samples_dim, coords=samples_coord)
    return labels
评论列表
文章目录