common_funcs.py source file

python

Project: model_sweeper  Author: akimovmike
import numpy as np
import pandas as pd
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score


def test_multicollinearity(df, target_name, r2_threshold=0.89):
    '''Tests whether any feature can be predicted from the others with R2 >= r2_threshold.

    input: dataframe, name of target (to exclude), R2 threshold (default 0.89)
    '''
    records, names = [], []
    for feature in df.columns.difference([target_name]):
        # Regress this feature on all remaining features (target excluded).
        others = df.columns.difference([target_name, feature])
        model = Ridge()
        model.fit(df[others], df[feature])

        # Mark the predictors holding the five largest (signed) coefficients.
        pos = np.isin(model.coef_, np.sort(model.coef_)[-5:])

        records.append({
            'r2': r2_score(df[feature], model.predict(df[others])),
            'predictors': str(others[pos].tolist()),
        })
        names.append(feature)
        print('Testing', feature)

    # Build the result table once; DataFrame.append was removed in pandas 2.0.
    r2s = pd.DataFrame(records, index=names, columns=['r2', 'predictors'])
    print('-----------------')

    flagged = r2s[r2s['r2'] >= r2_threshold]
    if len(flagged) > 0:
        print('Multicollinearity detected')
        print(flagged)
    else:
        print('No multicollinearity')
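
The idea behind the function above (regress each feature on the others and flag those with a high R2) can be tried on synthetic data. The following is a minimal sketch, not the project's code: it swaps Ridge for ordinary least squares so that only NumPy is needed, and the helper name `feature_r2_scores`, the column names `a` to `d`, and the generated data are all hypothetical.

```python
import numpy as np


def feature_r2_scores(X, names):
    """Regress each column of X on the remaining columns; return {name: R2}.

    Hypothetical helper: uses ordinary least squares, not Ridge.
    """
    scores = {}
    for j, name in enumerate(names):
        y = X[:, j]
        # Remaining columns plus an intercept column of ones.
        A = np.column_stack([np.delete(X, j, axis=1), np.ones(len(X))])
        coef, *_ = np.linalg.lstsq(A, y, rcond=None)
        ss_res = np.sum((y - A @ coef) ** 2)
        ss_tot = np.sum((y - y.mean()) ** 2)
        scores[name] = 1.0 - ss_res / ss_tot
    return scores


# Synthetic data (assumption for illustration only).
rng = np.random.default_rng(0)
a = rng.normal(size=200)
b = rng.normal(size=200)
c = rng.normal(size=200)
# d is almost an exact linear combination of a and b.
d = 2.0 * a - b + 0.01 * rng.normal(size=200)
X = np.column_stack([a, b, c, d])

scores = feature_r2_scores(X, ['a', 'b', 'c', 'd'])
flagged = sorted(name for name, r2 in scores.items() if r2 >= 0.89)
print(flagged)
```

Because `d` is built from `a` and `b`, each of those three columns can be recovered almost exactly from the other two, so all three exceed the 0.89 threshold, while the independent column `c` does not.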