def validate_formula(formula, training_data, column_being_predicted, cross_val_n=3, validation_size=.10):
    """
    Fit an OLS formula on repeated random splits and score each held-out part.

    On every round the rows of ``training_data`` are split with
    ``train_test_split``; the formula is fitted with ``smf.ols`` on the
    training part and the fitted model is scored on the held-out part with
    ``root_mean_log_squared_error``.

    Parameters:
        formula: model formula string, in the StatsModels formula-API style.
        training_data: table holding every column the formula refers to,
            including ``column_being_predicted``.
        column_being_predicted: key of the target column in ``training_data``.
        cross_val_n: number of split/fit/score rounds; must be at least 1.
        validation_size: passed to ``train_test_split`` as ``test_size``.

    Returns:
        (model, cross_val_scores): the model fitted on the LAST round's
        training split (not on the full data), and a list with one score per
        round, in round order.

    Raises:
        ValueError: if ``cross_val_n`` is less than 1.
    """
    # Without at least one round there is no fitted model to return.
    if cross_val_n < 1:
        raise ValueError(
            "cross_val_n must be at least 1, got %r" % (cross_val_n,)
        )

    cross_val_scores = []
    model = None
    # ``range`` rather than ``xrange`` so this runs on Python 2 and Python 3.
    for _ in range(cross_val_n):
        # The target column is split alongside the full table, so the
        # held-out targets line up row-for-row with X_test.
        X_train, X_test, _y_train, test_values = train_test_split(
            training_data,
            training_data[column_being_predicted],
            test_size=validation_size,
        )
        # Fit on the training rows only; X_train still contains the target
        # column, which the formula needs.
        model = smf.ols(formula=formula, data=X_train).fit()
        score = root_mean_log_squared_error(model, X_test, test_values)
        cross_val_scores.append(score)
    return (model, cross_val_scores)
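# (Web-page residue, not part of the code: "Comment list")
# (Web-page residue, not part of the code: "Table of contents")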