def correlation(self, vec_col, method="pearson"):
    """
    Compute the correlation matrix of a column of Vectors and plot it as a seaborn heatmap.

    The matrix is computed by ``pyspark.ml.stat.Correlation.corr`` on ``self._df``.

    :param vec_col: Name of the column of vectors for which the correlation coefficients are
        computed. It must be a column of ``self._df`` and must contain Vector objects.
    :param method: String naming the correlation method. Supported: ``'pearson'`` (default)
        and ``'spearman'``.
    :return: The object returned by ``seaborn.heatmap`` for the correlation matrix
        (the matplotlib Axes the heatmap is drawn on).
    :raises AssertionError: If ``method`` is not a string or is not a supported method name.
    """
    # NOTE(review): validation uses ``assert`` so callers keep seeing AssertionError;
    # be aware these checks are stripped when Python runs with -O.
    assert isinstance(method, str), "Error, method argument provided must be a string."
    assert method in ("pearson", "spearman"), \
        "Error, method only can be 'pearson' or 'spearman'."
    # Correlation.corr returns a single-row DataFrame whose only cell holds the matrix;
    # convert it to a dense NumPy array for plotting.
    cor = Correlation.corr(self._df, vec_col, method).head()[0].toArray()
    # ``np.bool`` no longer exists in NumPy >= 1.24; the builtin ``bool`` is the supported
    # dtype. An all-False mask hides no cells.
    return sns.heatmap(cor, mask=np.zeros_like(cor, dtype=bool),
                       cmap=sns.diverging_palette(220, 10, as_cmap=True))
评论列表
文章目录