task.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:rnnlab 作者: phueb 项目源码 文件源码
def gen_task_mbs(self, style, test_fold_id):
        """Yield (x, y) mini-batches of task windows for one fold split.

        The lines of the selected folds in ``self.task_folds`` are
        concatenated, converted to windows by ``self.make_windows`` and cut
        row-wise into chunks of ``self.mb_size`` rows.

        Args:
            style: Which folds to use and whether to repeat them.
                ``'train'``  - every fold except ``test_fold_id``; the
                windows are tiled N times, where N is the integer formed
                by the digit characters of ``self.regime``.
                ``'train1'`` - same folds as ``'train'``, no repetition.
                ``'test'``   - only the fold at index ``test_fold_id``.
            test_fold_id: Index into ``self.task_folds`` of the held-out
                fold.

        Yields:
            Tuples ``(x, y)`` holding ``self.mb_size`` rows of the x and y
            windows each. Nothing is yielded when the selection produces
            zero windows.

        Raises:
            AttributeError: If ``style`` is not one of the values above.
            ValueError: If ``style`` is ``'train'`` and ``self.regime``
                contains no digit.
            Exception: If the number of window rows is not a multiple of
                ``self.mb_size``.

        Note:
            This is a generator, so all of the above (including the
            exceptions) happens on first iteration, not at call time.
        """
        if style not in ('train', 'train1', 'test'):
            raise AttributeError('rnnlab: Invalid arg to "style"')
        # The repetition count is only meaningful for 'train'; parsing it
        # lazily keeps the other styles usable when the regime has no digits.
        num_task_iterations = None
        if style == 'train':
            num_task_iterations = int(
                ''.join(c for c in self.regime if c.isdigit()))
        # make blocks
        use_test_fold = style == 'test'
        task_lines = list(chain.from_iterable(
            fold for n, fold in enumerate(self.task_folds)
            if (n == test_fold_id) == use_test_fold))
        # NOTE(review): make_windows is assumed to return two arrays with
        # at least 2 dimensions and equal row counts (np.vsplit below
        # requires ndim >= 2) - confirm against its definition.
        block_x, block_y = self.make_windows(task_lines)
        if num_task_iterations is not None:
            block_x = np.tile(block_x, [num_task_iterations, 1])
            block_y = np.tile(block_y, [num_task_iterations, 1])
        # split to mbs
        # gcd(mb_size, n) == mb_size exactly when mb_size divides n; the
        # message mentions task_lines but the count checked is window rows.
        if gcd(self.mb_size, len(block_x)) != self.mb_size:
            raise Exception(
                'rnnlab: Number of task_lines must be divisible by mb_size')
        if len(block_x) == 0:
            # Zero rows would ask np.vsplit for 0 sections, which fails
            # with ZeroDivisionError; there is simply nothing to yield.
            return
        num_splits = len(block_x) // self.mb_size
        # generate
        for x, y in zip(np.vsplit(block_x, num_splits),
                        np.vsplit(block_y, num_splits)):
            yield x, y
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号