def predict_price(dates, prices, x, *, plot=True):
    """Fit a linear regression of price on date and predict the price at *x*.

    Args:
        dates: Sequence of numeric date values, one per observation.
        prices: Sequence of prices, same length as *dates*.
        x: Date value(s) to predict for; a scalar or a sequence. It is
            reshaped into a single-feature column because scikit-learn's
            ``predict`` requires 2-D input.
        plot: If True (the default, matching the original behaviour), draw
            the data points and the fitted line and call ``plt.show()``.
            Keyword-only.

    Returns:
        Tuple ``(predicted_price, coefficient, intercept)``: the prediction
        for the first value of *x*, the fitted slope, and the fitted
        intercept.
    """
    # scikit-learn expects X with shape (n_samples, n_features); here there
    # is a single feature, so convert to n x 1 matrices.
    dates = np.reshape(dates, (len(dates), 1))
    prices = np.reshape(prices, (len(prices), 1))
    # A scalar query such as 29 must also become a 2-D (1, 1) array;
    # passing a scalar or 1-D array to predict() raises ValueError.
    x = np.reshape(x, (-1, 1))

    linear_mod = linear_model.LinearRegression()  # the linear regression model
    linear_mod.fit(dates, prices)  # fit the data points

    if plot:
        plt.scatter(dates, prices, color='black', label='Data')  # raw data points
        plt.plot(dates, linear_mod.predict(dates), color='red', label='Linear model')  # fitted line
        plt.xlabel('Date')
        plt.ylabel('Price')
        plt.title('Linear Regression')
        plt.legend()
        plt.show()

    # prices was fitted as a 2-D column, so predict() returns shape (n, 1),
    # coef_ has shape (1, 1) and intercept_ has shape (1,); index out scalars.
    return linear_mod.predict(x)[0][0], linear_mod.coef_[0][0], linear_mod.intercept_[0]
# Scraped web-page residue: "评论列表" (comment list)
# Scraped web-page residue: "文章目录" (table of contents)