def median_heuristic(y, num_of_samples_used=100):
    """ Estimate RBF bandwidth using the median heuristic.

    The bandwidth is the median of the pairwise Euclidean distances
    between samples, divided by sqrt(2). To bound the quadratic cost of
    computing all pairwise distances, at most `num_of_samples_used`
    samples (drawn uniformly at random, without replacement) are used.

    Parameters
    ----------
    y : (number of samples, dimension)-ndarray
        One row of y corresponds to one sample. A 1-D ndarray is treated
        as (number of samples, 1)-shaped, i.e., one scalar sample per
        entry.
    num_of_samples_used : int, optional
        If y contains more samples than this, it is randomly subsampled
        to this cardinality (default: 100). Must be at least 2.

    Returns
    -------
    bandwidth : float
        Estimated RBF bandwidth.

    Raises
    ------
    ValueError
        If y contains fewer than 2 samples or num_of_samples_used < 2;
        at least one pairwise distance is needed, otherwise the median
        would be NaN.

    """
    if num_of_samples_used < 2:
        raise ValueError('num_of_samples_used must be at least 2.')

    # pdist requires a 2-D array; interpret a 1-D input as scalar samples.
    if y.ndim == 1:
        y = y.reshape(-1, 1)

    num_of_samples = y.shape[0]  # number of samples
    if num_of_samples < 2:
        raise ValueError('At least 2 samples are needed to estimate the '
                         'bandwidth.')

    # subsample y (if necessary): select num_of_samples_used random rows,
    # i.e., samples, without replacement:
    if num_of_samples > num_of_samples_used:
        idx = choice(num_of_samples, num_of_samples_used, replace=False)
        y = y[idx]  # fancy indexing: keep only the selected rows

    dist_vector = pdist(y)  # pairwise Euclidean distances
    bandwidth = median(dist_vector) / sqrt(2)

    return bandwidth
# (leftover web-page navigation label) Comment list
# (leftover web-page navigation label) Table of contents