def define_model(self):
    """Build the estimator object selected by ``self.modeltype``.

    Supported values of ``self.modeltype`` and the entries of
    ``self.parameters`` each one reads:

    * ``"RandomForest"``     -> ``ensemble.RandomForestRegressor`` (``n_estimators``)
    * ``"LinearRegression"`` -> ``linear_model.LinearRegression`` (no parameters)
    * ``"Lasso"``            -> ``linear_model.Lasso`` (``alpha``)
    * ``"ElasticNet"``       -> ``linear_model.ElasticNet`` (``alpha``, ``l1_ratio``)
    * ``"SVR"``              -> ``SVR`` (``C``, ``epsilon``, ``kernel``)

    ``self.parameters`` is only accessed by the branches that need it, so a
    parameter-free model does not require it to be populated.

    Returns:
        A newly constructed estimator instance for the requested model type.

    Raises:
        ConfigError: if ``self.modeltype`` is not one of the names above.
    """
    # NOTE: branches for "AR", "StaticModel", "AdvancedStaticModel" and
    # "SGDRegressor" (and a RandomForestClassifier variant) existed here only
    # as commented-out code; those model types are not supported.
    kind = self.modeltype

    if kind == "RandomForest":
        return ensemble.RandomForestRegressor(
            n_estimators=self.parameters['n_estimators'])

    if kind == "LinearRegression":
        return linear_model.LinearRegression()

    if kind == "Lasso":
        return linear_model.Lasso(alpha=self.parameters['alpha'])

    if kind == "ElasticNet":
        return linear_model.ElasticNet(
            alpha=self.parameters['alpha'],
            l1_ratio=self.parameters['l1_ratio'])

    if kind == "SVR":
        return SVR(
            C=self.parameters['C'],
            epsilon=self.parameters['epsilon'],
            kernel=self.parameters['kernel'])

    # Nothing matched: report the unrecognised model name to the caller.
    raise ConfigError("Unsupported model {0}".format(kind))
# (Web-page residue, not part of the code: "评论列表" = comment list, "文章目录" = article table of contents.)