def test_multitarget():
    """Check that LassoLars and Lars handle a two-column target correctly.

    A joint fit on a 2-D ``y`` must agree, column by column, with separate
    single-target fits.  The agreement is checked on ``alphas_``,
    ``active_``, ``coef_``, ``coef_path_`` and the predictions.  On the
    joint fit, ``decision_function`` must raise a ``DeprecationWarning``
    and return the same values as ``predict``.
    """
    features = diabetes.data
    targets = np.vstack([diabetes.target, diabetes.target ** 2]).T

    for model in (linear_model.LassoLars(), linear_model.Lars()):
        # Joint fit on both target columns at once.
        model.fit(features, targets)
        joint_pred = model.predict(features)
        joint_dec = assert_warns(DeprecationWarning,
                                 model.decision_function, features)
        assert_array_almost_equal(joint_pred, joint_dec)

        # Keep references to the multi-target attributes; the per-column
        # refits below rebind them on the same estimator.
        joint_alphas = model.alphas_
        joint_active = model.active_
        joint_coef = model.coef_
        joint_path = model.coef_path_

        # The k-th entry of each joint attribute must match a fit on
        # column k alone.
        for k, column in enumerate(targets.T):
            model.fit(features, column)
            single_pred = model.predict(features)
            assert_array_almost_equal(joint_alphas[k], model.alphas_)
            assert_array_almost_equal(joint_active[k], model.active_)
            assert_array_almost_equal(joint_coef[k], model.coef_)
            assert_array_almost_equal(joint_path[k], model.coef_path_)
            assert_array_almost_equal(joint_pred[:, k], single_pred)
# Comment list
# Table of contents