2linear_regression.py source file

python

Project: Python-Machine-Learning-By-Example Author: PacktPublishing
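import numpy as np


# update_weights_gd and compute_cost are helpers expected to be defined elsewhere in this file.
def train_linear_regression(X_train, y_train, max_iter, learning_rate, fit_intercept=False):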
    """ Train a linear regression model with gradient descent
    Args:
        X_train, y_train (numpy.ndarray, training data set)
        max_iter (int, number of iterations)
        learning_rate (float)
        fit_intercept (bool, with an intercept w0 or not)
    Returns:
        numpy.ndarray, learned weights
    """
    if fit_intercept:
        intercept = np.ones((X_train.shape[0], 1))
        X_train = np.hstack((intercept, X_train))
    weights = np.zeros(X_train.shape[1])
    for iteration in range(max_iter):
        weights = update_weights_gd(X_train, y_train, weights, learning_rate)
        # Check the cost for every 100 (for example) iterations
        if iteration % 100 == 0:
            print(compute_cost(X_train, y_train, weights))
    return weights
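
The function above calls two helpers, update_weights_gd and compute_cost, that this excerpt does not show. The sketch below is one plausible way to write them, not necessarily the project's own code. It assumes a mean squared error cost of (1/2m) * sum((prediction - y)^2) and a full-batch gradient step; compute_prediction is a hypothetical helper name introduced here for illustration.

```python
import numpy as np


def compute_prediction(X, weights):
    """Hypothetical helper: linear prediction y_hat = X . weights."""
    return np.dot(X, weights)


def update_weights_gd(X_train, y_train, weights, learning_rate):
    """One full-batch gradient descent step (assumed form, see lead-in)."""
    predictions = compute_prediction(X_train, weights)
    # Gradient of (1/2m) * sum((pred - y)^2) with respect to the weights
    gradient = np.dot(X_train.T, predictions - y_train) / X_train.shape[0]
    return weights - learning_rate * gradient


def compute_cost(X, y, weights):
    """Assumed cost: half the mean squared error over the data set."""
    predictions = compute_prediction(X, weights)
    return np.mean((predictions - y) ** 2) / 2.0
```

With helpers of this shape, the cost printed every 100 iterations should shrink as long as learning_rate is small enough for the data's scale. Note that with fit_intercept=True the returned weights have one extra leading entry for w0, so inputs at prediction time need the same column of ones prepended.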