def train_data(self, data_feature, window, LabelColumnName):
    """Build, fit, save and evaluate a model for ``window``.

    Steps performed:
      1. ``self.prepare_train_data(data_feature, LabelColumnName)`` supplies
         ``X_train, y_train, X_test, y_test``.
      2. ``self.build_model(window, ...)`` creates the model. The model is then
         fitted on the training split using ``self.paras.batch_size``,
         ``self.paras.epoch`` and ``self.paras.verbose``.
      3. The fitted model is persisted via ``self.save_training_model``.
      4. ``self.predict`` is run on the training split and on the test split.
      5. When ``self.paras.plot`` is truthy, the object returned by
         ``model.fit`` is passed to ``self.plot_training_curve``.

    Args:
        data_feature: Input forwarded to ``prepare_train_data`` (presumably a
            feature DataFrame -- confirm against callers).
        window: Value forwarded to ``build_model`` and ``save_training_model``.
        LabelColumnName: Label column name forwarded to ``prepare_train_data``.

    Returns:
        The fitted model object returned by ``build_model``.
    """
    X_train, y_train, X_test, y_test = self.prepare_train_data(
        data_feature, LabelColumnName
    )
    model = self.build_model(window, X_train, y_train, X_test, y_test)

    # Capture fit()'s return value because the plotting branch below needs it.
    # Without this, `history` is undefined and paras.plot=True raises NameError.
    # For a Keras model, fit() returns a History object (assumed here).
    # Validation options (validation_split / validation_data) are not passed.
    history = model.fit(
        X_train,
        y_train,
        batch_size=self.paras.batch_size,
        epochs=self.paras.epoch,
        verbose=self.paras.verbose,
    )

    self.save_training_model(model, window)

    # The results are not used here. The calls are kept in case predict() has
    # side effects such as reporting -- NOTE(review): confirm or drop.
    self.predict(model, X_train, y_train)
    self.predict(model, X_test, y_test)

    if self.paras.plot:
        self.plot_training_curve(history)
    return model
###################################
### ###
### Predicting ###
### ###
###################################
Stock_Prediction_Model_Stateless_LSTM.py 文件源码
python
阅读 25
收藏 0
点赞 0
评论 0
评论列表
文章目录