def gen_task_mbs(self, style, test_fold_id):
    """Yield (x, y) mini-batches of task windows for one cross-validation split.

    Args:
        style: Which folds to use and how often.
            'train'  -- every fold except ``test_fold_id``; the windows are
                        repeated N times, where N is the integer formed by
                        concatenating the digits found in ``self.regime``.
            'train1' -- every fold except ``test_fold_id``; windows used once.
            'test'   -- only the fold at index ``test_fold_id``; windows
                        used once.
        test_fold_id: Index into ``self.task_folds`` of the held-out fold.

    Yields:
        Tuples ``(x, y)`` of consecutive ``self.mb_size``-row slices of the
        arrays returned by ``self.make_windows``. If the selection contains
        no windows, nothing is yielded.

    Raises:
        AttributeError: ``style`` is not one of the values above.
        ValueError: ``self.mb_size`` is not positive, the number of windows
            is not a multiple of ``self.mb_size``, or (``'train'`` only)
            ``self.regime`` contains no digits.

    Because this is a generator, all validation runs on the first
    ``next()`` call, not when the method is called.
    """
    if style not in ('train', 'test', 'train1'):
        # ValueError would be more conventional, but callers may already
        # catch AttributeError, so the original exception type is kept.
        raise AttributeError('rnnlab: Invalid arg to "style"')
    mb_size = self.mb_size
    if mb_size <= 0:
        # Without this, the split below would divide by zero or produce
        # negative slice steps.
        raise ValueError(
            'rnnlab: mb_size must be a positive integer, got {!r}'.format(
                mb_size))

    # make blocks
    if style == 'test':
        folds = [fold for n, fold in enumerate(self.task_folds)
                 if n == test_fold_id]
    else:
        # 'train' and 'train1' both use every fold except the held-out one.
        folds = [fold for n, fold in enumerate(self.task_folds)
                 if n != test_fold_id]
    task_lines = list(chain.from_iterable(folds))
    block_x, block_y = self.make_windows(task_lines)

    if style == 'train':
        # Only 'train' needs the iteration count, so the regime is parsed
        # here; 'test'/'train1' no longer fail on a regime without digits.
        digits = ''.join(c for c in self.regime if c.isdigit())
        if not digits:
            raise ValueError(
                'rnnlab: regime {!r} contains no iteration count'.format(
                    self.regime))
        num_task_iterations = int(digits)
        # NOTE(review): assumes make_windows returns 2-D arrays with one
        # window per row -- TODO confirm. np.tile with [n, 1] then stacks
        # n copies of the windows along axis 0.
        block_x = np.tile(block_x, [num_task_iterations, 1])
        block_y = np.tile(block_y, [num_task_iterations, 1])

    # split to mbs
    num_windows = len(block_x)
    if num_windows % mb_size != 0:
        raise ValueError(
            'rnnlab: Number of windows ({}) must be divisible by '
            'mb_size ({})'.format(num_windows, mb_size))
    # Row slices along axis 0 (the same rows np.vsplit would produce for
    # 2-D arrays); an empty block simply yields nothing.
    for start in range(0, num_windows, mb_size):
        stop = start + mb_size
        yield block_x[start:stop], block_y[start:stop]
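# (scraped web-page residue, not code: "Comment list")
# (scraped web-page residue, not code: "Table of contents")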