ClassificationHmmGeneralize.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:AirTicketPredicting 作者: junlulocky 项目源码 文件源码
def __init__(self, isTrain):
        """Build the AdaBoost model and load the pre-parsed "general" dataset.

        Steps, in order: run the parent initializer, call
        ``dataPreprocessing()``, create the base decision tree and the
        AdaBoost ensemble wrapped around it, read the ``.npy`` arrays stored
        under ``inputGeneralClf_HmmParsed/``, and record the route codes.

        Args:
            isTrain: Forwarded unchanged to the parent class initializer.
        """
        super(ClassificationHmmGeneralize, self).__init__(isTrain)

        # Data preprocessing happens before any model object is created.
        self.dataPreprocessing()

        # Base learner. Despite the attribute name this is not a depth-1
        # stump: the tree is allowed to grow to depth 10.
        self.dt_stump = DecisionTreeClassifier(max_depth=10)

        # NOTE(review): newer scikit-learn releases rename `base_estimator`
        # to `estimator` and deprecate "SAMME.R" -- confirm the pinned
        # version before upgrading.
        self.ada = AdaBoostClassifier(
            base_estimator=self.dt_stump,
            n_estimators=5,
            learning_rate=1,
            algorithm="SAMME.R",
        )

        def as_column(values):
            # Reshape to (n, 1), where n is the length of the first axis.
            return values.reshape((values.shape[0], 1))

        # General data. Column layout as noted by the original author:
        #   features 0~7 : flight number dummy variables
        #   feature 8    : departure date
        #   feature 9    : observed date state
        #   feature 10   : minimum price
        #   feature 11   : maximum price
        #   feature 12   : output
        #   feature 13   : current price
        #   feature 14   : flight index
        self.X_general = np.load('inputGeneralClf_HmmParsed/X_train.npy')
        self.y_general = as_column(
            np.load('inputGeneralClf_HmmParsed/y_train.npy'))
        self.y_general_price = as_column(
            np.load('inputGeneralClf_HmmParsed/y_train_price.npy'))
        self.y_general_index = as_column(
            np.load('inputGeneralClf_HmmParsed/y_index.npy'))

        # Route codes; the list position gives the route number (1-based).
        self.routes_general = [
            "BGY_OTP",  # route 1
            "BUD_VKO",  # route 2
            "CRL_OTP",  # route 3
            "CRL_WAW",  # route 4
            "LTN_OTP",  # route 5
            "LTN_PRG",  # route 6
            "OTP_BGY",  # route 7
            "OTP_CRL",  # route 8
            "OTP_LTN",  # route 9
            "PRG_LTN",  # route 10
            "VKO_BUD",  # route 11
            "WAW_CRL",  # route 12
        ]
评论列表


问题


面经


文章

微信
公众号

扫码关注公众号