sklearn_intent_classifier.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:rasa_nlu 作者: RasaHQ 项目源码 文件源码
def train(self, training_data, config, **kwargs):
        # type: (TrainingData, RasaNLUConfig, **Any) -> None
        """Train the intent classifier on a data set.

        Fits an SVC (``probability=True``, ``class_weight='balanced'``) inside a
        ``GridSearchCV`` that searches over the ``C`` and ``kernel`` values from
        the ``intent_classifier_sklearn`` section of ``config``, scored with
        ``f1_weighted``. The fitted search object is stored on ``self.clf``.
        If the training data has fewer than two distinct intents, a warning is
        logged and nothing is fitted (``self.clf`` is left untouched).

        :param training_data: training data whose ``intent_examples`` each
            provide an ``"intent"`` label and a ``"text_features"`` vector.
        :param config: configuration; ``config["num_threads"]`` is passed to the
            grid search as ``n_jobs``, and the optional
            ``intent_classifier_sklearn`` section may set ``C`` (a value or a
            list of values, default ``[1, 2, 5, 10, 20, 100]``) and ``kernel``
            (a name or a list of names, default ``"linear"``).
        :param kwargs: not used.
        """
        examples = training_data.intent_examples
        labels = [e.get("intent") for e in examples]

        if len(set(labels)) < 2:
            logger.warning("Can not train an intent classifier. Need at least 2 different classes. "
                           "Skipping training of intent classifier.")
            return

        # Imported only here so that the skip path above does not require sklearn/numpy.
        from sklearn.model_selection import GridSearchCV
        from sklearn.svm import SVC
        import numpy as np

        y = self.transform_labels_str2num(labels)
        X = np.stack([example.get("text_features") for example in examples])

        # A missing config section comes back as None; fall back to the defaults below.
        sklearn_config = config.get("intent_classifier_sklearn") or {}
        C = sklearn_config.get("C", [1, 2, 5, 10, 20, 100])
        kernel = sklearn_config.get("kernel", "linear")

        # GridSearchCV's param_grid needs a sequence of candidates per parameter,
        # so a single configured value is wrapped in a list.
        if not isinstance(C, (list, tuple)):
            C = [C]
        kernels = kernel if isinstance(kernel, (list, tuple)) else [kernel]
        # dirty str fix because sklearn is expecting str not instance of basestr...
        tuned_parameters = [{"C": C, "kernel": [str(k) for k in kernels]}]

        # Aim for ~5 examples of the least frequent label per fold, clamped to
        # [2, MAX_CV_FOLDS]. NOTE(review): assumes transform_labels_str2num yields
        # contiguous non-negative ints, otherwise bincount contains zeros.
        cv_splits = max(2, min(MAX_CV_FOLDS, np.min(np.bincount(y)) // 5))

        self.clf = GridSearchCV(SVC(C=1, probability=True, class_weight='balanced'),
                                param_grid=tuned_parameters, n_jobs=config["num_threads"],
                                cv=cv_splits, scoring='f1_weighted', verbose=1)

        self.clf.fit(X, y)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号