linear.py source code

python

Project: regression-stock-prediction  Author: chaitjo
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn import linear_model


def predict_price(dates, prices, x):
    dates = np.reshape(dates, (len(dates), 1))  # convert to an n x 1 matrix (one feature per sample)
    prices = np.reshape(prices, (len(prices), 1))

    linear_mod = linear_model.LinearRegression()  # define the linear regression model
    linear_mod.fit(dates, prices)  # fit the model to the data points

    plt.scatter(dates, prices, color='black', label='Data')  # plot the original data points
    plt.plot(dates, linear_mod.predict(dates), color='red', label='Linear model')  # plot the fitted regression line
    plt.xlabel('Date')
    plt.ylabel('Price')
    plt.title('Linear Regression')
    plt.legend()
    plt.show()

    # scikit-learn expects a 2-D input of shape (n_samples, n_features), so reshape the query point(s)
    x = np.reshape(x, (-1, 1))
    # return the first prediction, the slope and the intercept as scalars
    return linear_mod.predict(x)[0][0], linear_mod.coef_[0][0], linear_mod.intercept_[0]
```
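`LinearRegression` fits an ordinary least-squares line. With a single feature, the slope and intercept it reports (`coef_[0][0]` and `intercept_[0]`) can also be computed in closed form: slope = Σ(x − x̄)(y − ȳ) / Σ(x − x̄)² and intercept = ȳ − slope · x̄. The sketch below is a minimal illustration using plain NumPy only. The helper names `fit_line` and `predict` are hypothetical and not part of the original project, and the sample data is made up so that the exact answer is known.

```python
import numpy as np


def fit_line(dates, prices):
    """Closed-form least-squares fit of prices = slope * dates + intercept.

    Hypothetical helper for illustration; not part of the original project.
    """
    x = np.asarray(dates, dtype=float)
    y = np.asarray(prices, dtype=float)
    x_mean, y_mean = x.mean(), y.mean()
    # slope = sum((x - x_mean) * (y - y_mean)) / sum((x - x_mean) ** 2)
    slope = np.sum((x - x_mean) * (y - y_mean)) / np.sum((x - x_mean) ** 2)
    # the least-squares line always passes through (x_mean, y_mean)
    intercept = y_mean - slope * x_mean
    return slope, intercept


def predict(slope, intercept, x):
    """Evaluate the fitted line at x (a scalar or an array)."""
    return slope * np.asarray(x, dtype=float) + intercept


# Illustrative data only: day-of-month values and made-up prices
dates = [1, 2, 3, 4, 5]
prices = [3.0, 5.0, 7.0, 9.0, 11.0]  # exactly 2 * date + 1

slope, intercept = fit_line(dates, prices)
print(slope, intercept)               # 2.0 1.0
print(predict(slope, intercept, 29))  # 59.0

# Cross-check against NumPy's degree-1 polynomial fit (returns slope first, then intercept)
print(np.polyfit(dates, prices, 1))
```

Because both routes minimize the same squared error, their slope and intercept should agree (up to floating-point rounding) with what `predict_price` returns for the same inputs, which makes this a quick sanity check that does not need a plotting window.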