def fit(self, x, y, *args, **kwargs):
    """Train this process's share of the forests and pickle each one.

    The global NumPy RNG is seeded with ``self.random_state`` plus
    ``mpiops.chunk_index``, then one seed per forest is drawn from it.
    Each fitted forest is written to ``self.temp_dir`` instead of being
    kept on the instance.

    Parameters
    ----------
    x, y
        Training inputs and targets; passed unchanged to
        ``RandomForestTransformed.fit``.
    *args, **kwargs
        Accepted for signature compatibility; not used here.

    Notes
    -----
    ``self.kwargs`` is left untouched: the per-forest ``random_state`` is
    injected into a copy, so fitting does not alter the stored
    constructor arguments.
    """
    # Seed the global RNG differently on every process so that each one
    # draws its own sequence of per-forest seeds.
    np.random.seed(self.random_state + mpiops.chunk_index)

    if self.parallel:
        # Split the forest indices across processes; keep only our chunk.
        process_rfs = np.array_split(range(self.forests),
                                     mpiops.chunks)[mpiops.chunk_index]
    else:
        # Every process trains all forests.
        process_rfs = range(self.forests)

    for t in process_rfs:
        print('training forest {} using '
              'process {}'.format(t, mpiops.chunk_index))

        # Give each forest its own random state. Build a fresh kwargs
        # dict so the instance's stored kwargs are not modified; an
        # existing 'random_state' entry is overridden, as before.
        forest_kwargs = dict(self.kwargs,
                             random_state=np.random.randint(0, 10000))
        rf = RandomForestTransformed(
            target_transform=self.target_transform,
            n_estimators=self.n_estimators,
            **forest_kwargs
        )
        rf.fit(x, y)

        if self.parallel:  # used in training
            pk_f = join(self.temp_dir,
                        'rf_model_{}.pk'.format(t))
        else:  # used when parallel is false, i.e., during x-val
            # Include the process index so processes that each train the
            # full set of forests do not write to the same file name.
            pk_f = join(self.temp_dir,
                        'rf_model_{}_{}.pk'.format(t, mpiops.chunk_index))
        with open(pk_f, 'wb') as fp:
            pickle.dump(rf, fp)

    if self.parallel:
        # Wait until every process has finished writing its forests.
        mpiops.comm.barrier()

    # Mark that we are now trained
    self._trained = True
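# NOTE(review): web-page residue, not code: "评论列表" (comment list)
# NOTE(review): web-page residue, not code: "文章目录" (article table of contents)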