wangbase.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:DiscourseSenser 作者: WladimirSidorenko 项目源码 文件源码
def train(self, a_train_data, a_dev_data=None, a_n_y=-1,
          a_i=-1, a_train_out=None, a_dev_out=None):
    """Train the model and optionally store its predictions.

    When grid search is enabled (``self._gs``), ``self._model`` is wrapped
    in a ``GridSearchCV`` that optimizes macro-averaged F1.  If non-empty
    development data is supplied, the dev instances are appended to the
    data used for fitting and ``self._devset_cv`` builds the CV split;
    otherwise a shuffled ``StratifiedKFold`` over the training labels is
    used.

    Args:
      a_train_data (tuple[list, dict]):
        list of training JSON data
      a_dev_data (tuple[list, dict] or None):
        list of development JSON data
      a_n_y (int):
        number of distinct classes (stored as ``self.n_y``)
      a_i (int):
        row index for the output predictions, passed on to
        ``self._predict``; no predictions are made when it is negative
      a_train_out (np.array or None):
        predictions for the training set
      a_dev_out (np.array or None):
        predictions for the development set

    Returns:
      None

    Note:
      ``a_train_out[i]`` and ``a_dev_out[i]`` are handed to
      ``self._predict``, which presumably fills them in place.  Training
      predictions always cover exactly the original training instances,
      never the dev instances appended for grid search.

    """
    self.n_y = a_n_y
    x_train, y_train = self._generate_ts(a_train_data)
    x_dev, y_dev = self._generate_ts(a_dev_data)
    # Data actually passed to ``fit``.  Kept separate from ``x_train`` so
    # that training predictions below never need to strip the dev part
    # again (``x_train[:-len(x_dev)]`` yields [] when ``x_dev`` is empty).
    x_fit, y_fit = x_train, y_train
    # determine cross-validation and grid-search strategy and fit the model
    if self._gs:
        use_dev_cv = bool(a_dev_data and a_dev_data[0])
        if use_dev_cv:
            # CV split is built from the *unextended* training labels plus
            # the number of dev instances appended after them.
            cv = self._devset_cv(y_train, len(y_dev), NFOLDS)
            x_fit = x_train + x_dev
            y_fit = y_train + y_dev
        else:
            cv = StratifiedKFold(y_train, n_folds=NFOLDS, shuffle=True)
        scorer = make_scorer(f1_score, average="macro")
        self._model = GridSearchCV(self._model, self.PARAM_GRID,
                                   scoring=scorer,
                                   cv=cv, n_jobs=self.N_JOBS, verbose=1)
    # Each instance's features are its last element.
    self._model.fit([el[-1] for el in x_fit], y_fit)
    # output best hyper-parameters
    if self._gs:
        print("Best params:", repr(self._model.best_params_),
              file=sys.stderr)
    if a_i < 0:
        return
    if a_train_out is not None:
        for i, x_i in x_train:
            self._predict(x_i, a_train_out[i], a_i)
    if a_dev_out is not None:
        for i, x_i in x_dev:
            self._predict(x_i, a_dev_out[i], a_i)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号