models.py source code

python

Project: uncover-ml  Author: GeoscienceAustralia
def fit(self, x, y, *args, **kwargs):

        # set a different random seed for each MPI process (chunk)
        np.random.seed(self.random_state + mpiops.chunk_index)

        if self.parallel:
            process_rfs = np.array_split(range(self.forests),
                                         mpiops.chunks)[mpiops.chunk_index]
        else:
            process_rfs = range(self.forests)

        for t in process_rfs:
            print('training forest {} using '
                  'process {}'.format(t, mpiops.chunk_index))

            # change random state in each forest
            self.kwargs['random_state'] = np.random.randint(0, 10000)
            rf = RandomForestTransformed(
                target_transform=self.target_transform,
                n_estimators=self.n_estimators,
                **self.kwargs
                )
            rf.fit(x, y)
            if self.parallel:  # used in training
                pk_f = join(self.temp_dir,
                            'rf_model_{}.pk'.format(t))
            else:  # used when parallel is false, i.e., during cross-validation
                pk_f = join(self.temp_dir,
                            'rf_model_{}_{}.pk'.format(t, mpiops.chunk_index))
            with open(pk_f, 'wb') as fp:
                pickle.dump(rf, fp)
        if self.parallel:
            mpiops.comm.barrier()
        # Mark that we are now trained
        self._trained = True
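
The work partitioning and per-process seeding used in `fit` can be tried in isolation with the minimal sketch below. It is not part of uncover-ml: the helper names `forests_for_process` and `forest_seeds` are hypothetical, `chunk_index` and `n_chunks` stand in for `mpiops.chunk_index` and `mpiops.chunks`, and a local `np.random.RandomState` is swapped in for the global `np.random.seed` call so that the example has no side effects.

```python
import numpy as np


def forests_for_process(n_forests, n_chunks, chunk_index):
    """Return the forest indices assigned to one process (hypothetical helper).

    Mirrors np.array_split(range(self.forests), mpiops.chunks)[mpiops.chunk_index].
    """
    return np.array_split(np.arange(n_forests), n_chunks)[chunk_index]


def forest_seeds(random_state, chunk_index, n_local_forests):
    """Draw one seed per local forest from a per-process generator.

    Assumption: a local RandomState replaces the global np.random.seed
    used in the original method.
    """
    rng = np.random.RandomState(random_state + chunk_index)
    return [int(rng.randint(0, 10000)) for _ in range(n_local_forests)]


if __name__ == "__main__":
    n_forests, n_chunks = 10, 4
    for chunk_index in range(n_chunks):
        mine = forests_for_process(n_forests, n_chunks, chunk_index)
        seeds = forest_seeds(1, chunk_index, len(mine))
        print(chunk_index, mine.tolist(), seeds)
```

With 10 forests and 4 processes, `np.array_split` gives the first two processes three forests each and the last two processes two each, so no forest is trained twice when `parallel` is true.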