Stock_Prediction_Model_Stateless_LSTM.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目: StockRecommendSystem  作者: doncat99  项目源码  文件源码
def train_data(self, data_feature, window, LabelColumnName):
        """Build, fit, save and evaluate a model for ``window``.

        Args:
            data_feature: Feature data, forwarded unchanged to
                ``self.prepare_train_data``.
            window: Forwarded to ``self.build_model`` and
                ``self.save_training_model``.
            LabelColumnName: Label column name, forwarded to
                ``self.prepare_train_data``.

        Returns:
            The model returned by ``self.build_model``, after it has been
            fitted on the training split.
        """
        X_train, y_train, X_test, y_test = self.prepare_train_data(data_feature, LabelColumnName)
        model = self.build_model(window, X_train, y_train, X_test, y_test)

        # Keep the value returned by fit(): plot_training_curve() below needs it.
        # Previously ``history`` was never assigned (the History() callback was
        # commented out), so enabling ``paras.plot`` raised NameError.
        # NOTE(review): presumably a Keras History object — confirm against
        # build_model() and plot_training_curve().
        history = model.fit(
            X_train,
            y_train,
            batch_size=self.paras.batch_size,
            epochs=self.paras.epoch,
            verbose=self.paras.verbose,
        )

        self.save_training_model(model, window)

        # Evaluate on both splits. The recall values are not used here; the
        # calls are kept because predict() may report results as a side effect.
        recall_train, _ = self.predict(model, X_train, y_train)
        recall_test, _ = self.predict(model, X_test, y_test)

        if self.paras.plot:
            self.plot_training_curve(history)

        return model


    ###################################
    ###                             ###
    ###         Predicting          ###
    ###                             ###
    ###################################
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号