def __init__(self, isTrain):
    """Build the classifier objects and load the pre-parsed general data.

    The parent constructor is invoked first, then ``dataPreprocessing()``
    runs, then the decision tree / AdaBoost pair is created, and finally
    the general training arrays are read from ``inputGeneralClf_HmmParsed/``.

    Args:
        isTrain: Forwarded unchanged to the parent class constructor.
    """
    super(ClassificationHmmGeneralize, self).__init__(isTrain)

    # Data preprocessing happens before any estimator is instantiated.
    self.dataPreprocessing()

    # Base learner for the ensemble; note that, despite the "stump" name,
    # the tree is allowed to grow to depth 10.
    self.dt_stump = DecisionTreeClassifier(max_depth=10)
    self.ada = AdaBoostClassifier(
        base_estimator=self.dt_stump,
        learning_rate=1,
        n_estimators=5,
        algorithm="SAMME.R")

    def _load_as_column(npy_path):
        # Read a saved array and reshape it to (n_rows, 1), where n_rows is
        # the length of its first axis.
        values = np.load(npy_path)
        return values.reshape((values.shape[0], 1))

    # General data layout:
    #   feature 0~7 : flight number dummy variables
    #   feature 8   : departure date
    #   feature 9   : observed date state
    #   feature 10  : minimum price
    #   feature 11  : maximum price
    #   feature 12  : output
    #   feature 13  : current price
    #   feature 14  : flight index
    # NOTE(review): features 12-14 presumably correspond to the three
    # y_* arrays loaded below -- confirm against the data-export script.
    self.X_general = np.load('inputGeneralClf_HmmParsed/X_train.npy')
    self.y_general = _load_as_column(
        'inputGeneralClf_HmmParsed/y_train.npy')
    self.y_general_price = _load_as_column(
        'inputGeneralClf_HmmParsed/y_train_price.npy')
    self.y_general_index = _load_as_column(
        'inputGeneralClf_HmmParsed/y_index.npy')

    # Route codes for the general data, in order: position k in this list
    # is "route k+1".
    self.routes_general = [
        "BGY_OTP",  # route 1
        "BUD_VKO",  # route 2
        "CRL_OTP",  # route 3
        "CRL_WAW",  # route 4
        "LTN_OTP",  # route 5
        "LTN_PRG",  # route 6
        "OTP_BGY",  # route 7
        "OTP_CRL",  # route 8
        "OTP_LTN",  # route 9
        "PRG_LTN",  # route 10
        "VKO_BUD",  # route 11
        "WAW_CRL",  # route 12
    ]
ClassificationHmmGeneralize.py 文件源码
python
阅读 24
收藏 0
点赞 0
评论 0
评论列表
文章目录