classification.py source code

Project: pyts  Author: johannfaouzi
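import numpy as np
from sklearn.exceptions import NotFittedError
from sklearn.metrics.pairwise import cosine_similarity


def predict(self, X):
    """Predict the class labels for the provided data.

    Parameters
    ----------
    X : array-like, shape = [n_samples]
        Each element is the sequence of words representing one sample.

    Returns
    -------
    y : np.ndarray, shape = [n_samples]
        Class labels for each data sample.
    """
    if not self.fitted:
        raise NotFittedError("Estimator not fitted, call `fit` before exploiting the model.")

    # Map each vocabulary word to its column in the frequency matrix.
    word_to_col = {word: j for j, word in enumerate(self.all_words_)}

    # Term-frequency matrix: row i counts how often each vocabulary word
    # occurs in sample i. Words not seen during `fit` are ignored.
    n_samples = len(X)
    frequencies = np.zeros((n_samples, self.n_all_words_))
    for i in range(n_samples):
        words_unique, words_counts = np.unique(X[i], return_counts=True)
        for word, count in zip(words_unique, words_counts):
            j = word_to_col.get(word)
            if j is not None:
                frequencies[i, j] = count

    self.frequencies_ = frequencies

    # Pick, for each sample, the row of tf_idf_array_ (one row per class)
    # with the highest cosine similarity. argmax returns a row index, which
    # equals the class label only when labels are 0, ..., n_classes - 1.
    y_pred = cosine_similarity(frequencies, self.tf_idf_array_).argmax(axis=1)

    return y_pred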
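The method above scores each sample by the cosine similarity between its word-count vector and one tf-idf vector per class. Below is a minimal, self-contained sketch of that idea using NumPy only; it is not pyts's actual implementation. The class ToyVSMClassifier, the helpers term_frequencies and cosine_similarity_matrix, and the log(1 + tf) * log(n_classes / df) weighting in fit are hypothetical choices made for illustration.

```python
import numpy as np


def term_frequencies(X, vocabulary):
    """Return an (n_samples, n_words) matrix of word counts.

    X is a list of word sequences; words outside `vocabulary` are ignored.
    """
    col = {word: j for j, word in enumerate(vocabulary)}
    freqs = np.zeros((len(X), len(vocabulary)))
    for i, words in enumerate(X):
        for word in words:
            j = col.get(word)
            if j is not None:
                freqs[i, j] += 1
    return freqs


def cosine_similarity_matrix(A, B):
    """Pairwise cosine similarity between the rows of A and the rows of B."""
    a_norm = np.linalg.norm(A, axis=1, keepdims=True)
    b_norm = np.linalg.norm(B, axis=1, keepdims=True)
    # All-zero rows keep a similarity of 0 instead of dividing by zero.
    a_unit = A / np.where(a_norm == 0, 1.0, a_norm)
    b_unit = B / np.where(b_norm == 0, 1.0, b_norm)
    return a_unit @ b_unit.T


class ToyVSMClassifier:
    """Hypothetical vector-space classifier with one tf-idf vector per class."""

    def fit(self, X, y):
        y = np.asarray(y)
        self.classes_ = np.unique(y)
        self.all_words_ = sorted({word for words in X for word in words})
        tf = term_frequencies(X, self.all_words_)
        # Merge all samples of a class into one "class document".
        class_tf = np.array([tf[y == c].sum(axis=0) for c in self.classes_])
        # Assumed weighting: log(1 + tf) * log(n_classes / df).
        df = (class_tf > 0).sum(axis=0)  # every vocabulary word has df >= 1
        idf = np.log(len(self.classes_) / df)
        self.tf_idf_array_ = np.log1p(class_tf) * idf
        return self

    def predict(self, X):
        freqs = term_frequencies(X, self.all_words_)
        sims = cosine_similarity_matrix(freqs, self.tf_idf_array_)
        # argmax returns a row index of tf_idf_array_; map it to the label.
        return self.classes_[sims.argmax(axis=1)]


X_train = [["aab", "abb", "aab"], ["abb", "aab"],
           ["bba", "bbb"], ["bbb", "bba", "bbb"]]
y_train = ["up", "up", "down", "down"]
clf = ToyVSMClassifier().fit(X_train, y_train)
print(clf.predict([["aab", "abb"], ["bbb", "zzz"]]))  # ['up' 'down']
```

Storing the sorted labels in classes_ during fit and indexing it in predict follows the usual scikit-learn convention, so string labels such as "up" and "down" come back unchanged. The original method returns the raw argmax instead, which matches the labels only when they are 0, ..., n_classes - 1.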