def mog_test(self, embedding, labels, n_genres=None, random_state=None):
    """
    Evaluate how well the given embedding performs on the mixture of
    gaussians task.

    A ``GaussianMixture`` (presumably ``sklearn.mixture.GaussianMixture``,
    imported elsewhere in this file) is fit with one component per genre.
    Each song is then assigned to a component, and the assignments are
    scored against the true labels via ``self.get_scores``.

    @param embedding: a list of vectors [v1, v2, ..., vn] of song embeddings
        in R^k.
    @param labels: a list of genres where labels[i] is the genre of
        embedding[i].
    @param n_genres: the number of genres (for convenience). If None is
        given, it is computed as the number of distinct values in ``labels``.
    @param random_state: forwarded to ``GaussianMixture`` so that results can
        be made reproducible. None (the default) keeps the original unseeded
        behavior.
    @return: the result of ``self.get_scores(labels, p_labels)``, where
        ``p_labels`` are the predicted component indices.
    """
    if n_genres is None:
        # One mixture component per distinct genre present in the labels.
        n_genres = len(set(labels))
    clf = GaussianMixture(n_components=n_genres, random_state=random_state)
    clf.fit(embedding)
    # Predicted component index for every song in the embedding.
    p_labels = clf.predict(embedding)
    return self.get_scores(labels, p_labels)
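# Stray page-navigation labels from the page this code was copied from (not code):
# "评论列表" (Comment list), "文章目录" (Table of contents)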