twosls.py source file

python

Project: DeepIV · Author: jhartford
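import numpy as np
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures


def fit_twosls(x, z, t, y):
    '''
    Two-stage least squares with a polynomial basis expansion.

    x: observed covariates, shape (n, d_x); may have zero columns
    z: instruments, shape (n, d_z)
    t: treatment, shape (n, 1)
    y: outcome
    Returns g_hat(x, z, t), the fitted second-stage response function.
    '''
    # Shared search space: polynomial degree 1-3, ridge penalty 1e-5 ... 1e5.
    params = dict(poly__degree=list(range(1, 4)),
                  ridge__alpha=np.logspace(-5, 5, 11))

    # Stage 1: regress the treatment t on the covariates and instruments.
    pipe = Pipeline([('poly', PolynomialFeatures()),
                     ('ridge', Ridge())])
    stage_1 = GridSearchCV(pipe, param_grid=params, cv=5)
    if x.shape[1] > 0:
        X = np.concatenate([x, z], axis=1)
    else:
        X = z  # no observed covariates: use the instruments alone
    stage_1.fit(X, t)
    # Keep t_hat 2-D so it can be concatenated column-wise with x.
    t_hat = stage_1.predict(X).reshape(len(X), -1)
    print("First stage parameters: " + str(stage_1.best_params_))

    # Stage 2: regress the outcome y on the covariates and the predicted treatment.
    pipe2 = Pipeline([('poly', PolynomialFeatures()),
                      ('ridge', Ridge())])
    stage_2 = GridSearchCV(pipe2, param_grid=params, cv=5)
    X2 = np.concatenate([x, t_hat], axis=1)
    stage_2.fit(X2, y)
    print("Best in-sample score: %f" % stage_2.score(X2, y))
    print("Second stage parameters: " + str(stage_2.best_params_))

    def g_hat(x, z, t):
        # z is unused: the structural response depends only on x and t.
        # t must be 2-D, shape (n, 1).
        X_new = np.concatenate([x, t], axis=1)
        return stage_2.predict(X_new)
    return g_hat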
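The sketch below shows what the two stages accomplish in the simplest case. It implements textbook linear 2SLS in plain NumPy and assumes a single treatment, no covariates, and a linear model. Unlike `fit_twosls`, it uses no polynomial expansion, no ridge penalty, and no cross-validation. The helper names `ols` and `linear_2sls` and the synthetic data-generating process are illustrative assumptions, not part of DeepIV.

```python
import numpy as np


def ols(X, y):
    """Least-squares coefficients for y ~ X (X already includes any intercept column)."""
    coef, *_ = np.linalg.lstsq(X, y, rcond=None)
    return coef


def linear_2sls(z, t, y):
    """Hypothetical helper: linear 2SLS with one treatment t and instruments z.

    Returns [intercept, effect of t]. The point estimates match 2SLS, but
    standard errors from the naive second-stage fit would not be valid.
    """
    n = len(y)
    ones = np.ones((n, 1))
    Z = np.hstack([ones, z.reshape(n, -1)])      # stage-1 design: intercept + instruments
    t_hat = Z @ ols(Z, t)                        # stage 1: project t onto the instruments
    X2 = np.hstack([ones, t_hat.reshape(n, 1)])  # stage-2 design: intercept + fitted treatment
    return ols(X2, y)                            # stage 2: regress y on the fitted treatment


# Assumed synthetic data: e confounds t and y, and z shifts t but not y directly.
rng = np.random.default_rng(0)
n = 20_000
z = rng.normal(size=n)                 # instrument
e = rng.normal(size=n)                 # unobserved confounder
t = z + e                              # endogenous treatment
y = 2.0 * t + 3.0 * e + rng.normal(scale=0.1, size=n)  # true effect of t is 2.0

naive = ols(np.column_stack([np.ones(n), t]), y)[1]
iv = linear_2sls(z, t, y)[1]
print(f"naive OLS slope: {naive:.3f}, 2SLS slope: {iv:.3f}")
```

Under this design, the naive slope tends to 2 + 3·Cov(t, e)/Var(t) = 3.5. The instrumented estimate instead targets Cov(z, y)/Cov(z, t) = 2.0, so the two printed values should land near those numbers. `fit_twosls` follows the same two-stage structure but replaces each linear regression with a cross-validated polynomial ridge fit.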