RegressionDecisionTree.py source code

python

Project: AirTicketPredicting   Author: junlulocky
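import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import validation_curve
from sklearn.tree import DecisionTreeRegressor


def drawValidationCurve(self):
    """
    Draw the validation curve of a DecisionTreeRegressor over max_depth.
    :return: None
    """
    X, y = self.X_train, self.y_train.ravel()
    indices = np.arange(y.shape[0])
    # np.random.shuffle(indices)  # uncomment to shuffle; as written, the original order is kept
    X, y = X[indices], y[indices]

    # Candidate values of max_depth (the hyperparameter being varied, not training-set sizes).
    max_depths = range(2, 60)
    # Current scikit-learn takes param_name/param_range as keywords and names the
    # MSE scorer 'neg_mean_squared_error'; it returns -MSE for every fold.
    train_scores, valid_scores = validation_curve(
        DecisionTreeRegressor(max_features=None), X, y,
        param_name="max_depth", param_range=max_depths,
        cv=5, scoring="neg_mean_squared_error")
    # Negate to recover per-fold MSE; each array has shape (len(max_depths), 5).
    # (The original scaled by -1/5, which shrinks the curves and no longer gives MSE.)
    train_scores = -train_scores
    valid_scores = -valid_scores

    # Mean and spread across the 5 folds for each depth.
    train_scores_mean = np.mean(train_scores, axis=1)
    train_scores_std = np.std(train_scores, axis=1)
    valid_scores_mean = np.mean(valid_scores, axis=1)
    valid_scores_std = np.std(valid_scores, axis=1)

    # Shaded bands show +/- one standard deviation around each mean curve.
    plt.fill_between(max_depths, train_scores_mean - train_scores_std,
                     train_scores_mean + train_scores_std, alpha=0.1, color="r")
    plt.fill_between(max_depths, valid_scores_mean - valid_scores_std,
                     valid_scores_mean + valid_scores_std, alpha=0.1, color="g")
    plt.plot(max_depths, train_scores_mean, 'o-', color="r",
             label="Training MSE")
    plt.plot(max_depths, valid_scores_mean, '*-', color="g",
             label="Cross-validation MSE")

    plt.legend(loc="best")
    plt.xlabel('Max Depth')
    plt.ylabel('MSE')
    plt.title('Validation Curve for Decision Tree Regression\n(varying max_depth)')
    plt.grid(True)
    plt.show()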
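The method delegates all per-fold work to scikit-learn's `validation_curve`. To make explicit what those score arrays contain, the sketch below recomputes the same curve by hand with `KFold` and then cross-checks it against `validation_curve`. It is a hedged illustration, not code from the AirTicketPredicting project: the helper name `manual_depth_curve` is hypothetical, and the sine-plus-noise data and the shorter depth range are assumptions standing in for `self.X_train` / `self.y_train`, which are not shown. It relies on the fact that an integer `cv=5` on a regressor means an unshuffled `KFold(5)`.

```python
import numpy as np
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import KFold, validation_curve
from sklearn.tree import DecisionTreeRegressor


def manual_depth_curve(X, y, depths, n_splits=5):
    """Hypothetical helper: per-fold train/validation MSE for each max_depth.

    Returns two arrays of shape (len(depths), n_splits), holding positive MSE
    (validation_curve with 'neg_mean_squared_error' returns the negatives).
    """
    # Unshuffled KFold, which is what cv=5 means for a regressor.
    folds = list(KFold(n_splits=n_splits).split(X))
    train_mse = np.empty((len(depths), n_splits))
    valid_mse = np.empty((len(depths), n_splits))
    for i, depth in enumerate(depths):
        for j, (tr, va) in enumerate(folds):
            model = DecisionTreeRegressor(max_depth=depth, random_state=0)
            model.fit(X[tr], y[tr])
            # Training error is measured on the fold the tree was fit on.
            train_mse[i, j] = mean_squared_error(y[tr], model.predict(X[tr]))
            # Validation error is measured on the held-out fold.
            valid_mse[i, j] = mean_squared_error(y[va], model.predict(X[va]))
    return train_mse, valid_mse


# Assumed synthetic data standing in for self.X_train / self.y_train.
rng = np.random.RandomState(0)
X = rng.uniform(-3, 3, size=(200, 1))
y = np.sin(X).ravel() + rng.normal(scale=0.3, size=200)

depths = list(range(2, 12))
train_mse, valid_mse = manual_depth_curve(X, y, depths)

# Cross-check against scikit-learn's own validation_curve.
sk_train, sk_valid = validation_curve(
    DecisionTreeRegressor(random_state=0), X, y,
    param_name="max_depth", param_range=depths,
    cv=5, scoring="neg_mean_squared_error")
print("matches validation_curve:", np.allclose(-sk_valid, valid_mse))

# One common reading of the curve: pick the depth with the lowest mean CV MSE.
best_depth = depths[int(np.argmin(valid_mse.mean(axis=1)))]
print("depth with lowest mean CV MSE:", best_depth)
```

Training MSE can only stay flat or fall as `max_depth` grows, because a deeper greedy tree refines the leaves of a shallower one; that is why the cross-validation curve, not the training curve, is the one to read when choosing a depth.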