def train_logistic_regression(X_train, y_train, max_iter, learning_rate,
                              fit_intercept=False, *, print_every=2):
    """Train a logistic regression model by repeated SGD weight updates.

    Args:
        X_train (numpy.ndarray): 2-D training feature matrix (rows are
            stacked against a ones column when ``fit_intercept`` is True).
        y_train (numpy.ndarray): Training targets; passed unchanged to
            ``update_weights_sgd`` and ``compute_cost``.
        max_iter (int): Number of times ``update_weights_sgd`` is called.
        learning_rate (float): Step size forwarded to ``update_weights_sgd``.
        fit_intercept (bool): If True, prepend a column of ones to
            ``X_train`` so that ``weights[0]`` plays the role of the
            intercept w0.
        print_every (int or None): Print the value returned by
            ``compute_cost`` on every iteration whose index is a multiple of
            this value (iteration 0 included). ``None`` or ``0`` disables
            printing. Defaults to 2, the interval previously hard-coded.

    Returns:
        numpy.ndarray: Learned weights, one per column of ``X_train``
        (one extra leading weight when ``fit_intercept`` is True).
    """
    if fit_intercept:
        # np.hstack builds a new array, so the caller's X_train is untouched.
        intercept = np.ones((X_train.shape[0], 1))
        X_train = np.hstack((intercept, X_train))
    weights = np.zeros(X_train.shape[1])
    for iteration in range(max_iter):
        # NOTE(review): presumably one SGD pass over the data; see
        # update_weights_sgd for the exact update rule.
        weights = update_weights_sgd(X_train, y_train, weights, learning_rate)
        # Periodically report the cost to monitor convergence; the
        # truthiness guard lets callers switch this off (and avoids a
        # modulo-by-zero when print_every == 0).
        if print_every and iteration % print_every == 0:
            print(compute_cost(X_train, y_train, weights))
    return weights
# Train the SGD model based on 10000 samples
3logistic_regression_from_scratch.py 文件源码
python
阅读 28
收藏 0
点赞 0
评论 0
评论列表
文章目录