date_splitter.py source code

python

Project: guacml  Author: guacml
import pandas as pd


def cv_splits(self, input):
    # Walk backwards from the latest date, cutting n_folds consecutive
    # validation windows of prediction_length days each.
    dates = input[self.date_split_col]
    window = pd.Timedelta(days=self.prediction_length)
    left = dates.max()
    split_points = []
    for _ in range(self.n_folds):
        right = left
        left = left - window
        split_points.append((left, right))
    split_points.reverse()  # oldest window first
    # The data before the earliest window must span at least one prediction length.
    if split_points[0][0] - dates.min() < window:
        raise Exception('Training set is shorter than the prediction length. Use fewer '
                        'cross validation folds or a shorter prediction length')

    split_indices = []
    for left, right in split_points:
        # Train on everything strictly before the window; validate on [left, right).
        train = input[dates < left]
        cv = input[(dates >= left) & (dates < right)]
        split_indices.append([train.index.values, cv.index.values])

    return split_indices
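
To see what the method above returns, the same fold logic can be run on a toy frame. This is only a sketch: the standalone function `date_cv_splits`, its parameter names, and the `day` column are hypothetical stand-ins for the class attributes (`date_split_col`, `n_folds`, `prediction_length`), and the training-length check is left out for brevity.

```python
import pandas as pd


def date_cv_splits(df, date_col, n_folds, prediction_length):
    """Hypothetical standalone version of the fold logic (no length check)."""
    dates = df[date_col]
    window = pd.Timedelta(days=prediction_length)
    left = dates.max()
    split_points = []
    for _ in range(n_folds):
        right = left
        left = left - window
        split_points.append((left, right))
    split_points.reverse()  # oldest window first

    splits = []
    for lo, hi in split_points:
        train = df[dates < lo]                      # strictly before the window
        cv = df[(dates >= lo) & (dates < hi)]       # half-open window [lo, hi)
        splits.append((train.index.values, cv.index.values))
    return splits


# Ten consecutive days (index 0..9), two folds of three days each.
frame = pd.DataFrame({"day": pd.date_range("2020-01-01", periods=10, freq="D")})
splits = date_cv_splits(frame, "day", n_folds=2, prediction_length=3)
for train_idx, cv_idx in splits:
    print(train_idx.tolist(), cv_idx.tolist())
```

On this frame the first fold trains on rows 0 to 2 and validates on rows 3 to 5; the second trains on rows 0 to 5 and validates on rows 6 to 8. Because the right bound of each window is exclusive, the row carrying the latest date (row 9 here) lands in no validation fold.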