def _fit(self, dataset):
    """Pick the best param map with one train/validation split, then refit.

    A column named ``self.uid + "_rand"`` holding ``rand(seed)`` is appended
    to ``dataset``. Rows whose value in that column is ``>= trainRatio`` form
    the validation set; all remaining rows form the training set.
    (``rand`` is presumably ``pyspark.sql.functions.rand`` -- confirm against
    the file's imports.)

    For every entry of ``estimatorParamMaps`` the estimator is fitted on the
    training set and scored by the evaluator on the transformed validation
    set. The entry with the largest metric wins when
    ``evaluator.isLargerBetter()`` is true, otherwise the smallest one wins.

    :param dataset: input dataset; must support ``select`` and column
        indexing as used below.
    :return: ``self._copyValues`` applied to a ``TrainValidationSplitModel``
        built from the estimator refitted on the *whole* ``dataset`` with the
        winning param map, together with the list of validation metrics.
    """
    estimator = self.getOrDefault(self.estimator)
    param_maps = self.getOrDefault(self.estimatorParamMaps)
    model_count = len(param_maps)
    evaluator = self.getOrDefault(self.evaluator)
    train_ratio = self.getOrDefault(self.trainRatio)
    seed = self.getOrDefault(self.seed)

    # Tag each row with a seeded random value and split on it.
    rand_col = self.uid + "_rand"
    tagged = dataset.select("*", rand(seed).alias(rand_col))
    in_validation = (tagged[rand_col] >= train_ratio)
    validation = tagged.filter(in_validation)
    train = tagged.filter(~in_validation)

    # One fit + one evaluation per candidate param map, in order.
    metrics = [0.0] * model_count
    for idx, param_map in enumerate(param_maps):
        candidate = estimator.fit(train, param_map)
        metrics[idx] += evaluator.evaluate(candidate.transform(validation, param_map))

    # The evaluator decides whether a higher or a lower metric is better.
    select_best = np.argmax if evaluator.isLargerBetter() else np.argmin
    best_index = select_best(metrics)

    # The final model is trained on the full input, not just the training split.
    best_model = estimator.fit(dataset, param_maps[best_index])
    return self._copyValues(TrainValidationSplitModel(best_model, metrics))
# Comment list
# Article table of contents