def prepare_batches(self, batch_data, chunks, **kwargs):
    """Split a batch into pieces, prepare each piece, and restack the results.

    ``batch_data`` is unpacked as ``(x, x_lens, ys, ys_lens)``. ``x`` and
    ``x_lens`` are always split along dim 0. ``ys`` and ``ys_lens`` are split
    along dim 0 when ``self.batch_first`` is truthy, and along dim 1
    otherwise. Each resulting 4-tuple of pieces is passed to
    ``self.prepare_batch(piece, **kwargs)``.

    The per-piece outputs are transposed from piece-major to field-major.
    Each field is then stacked along a new leading dimension, so every
    returned tensor has size ``len(pieces)`` in dim 0.

    Args:
        batch_data: 4-sequence ``(x, x_lens, ys, ys_lens)`` of tensors.
        chunks: Number of pieces requested from ``Tensor.chunk``. As with
            ``torch.chunk``, fewer pieces may be produced, and the last one
            may be smaller, when the split dimension is not divisible.
        **kwargs: Forwarded unchanged to ``self.prepare_batch``.

    Returns:
        list of tensors, one per field returned by ``self.prepare_batch``.

    Raises:
        RuntimeError: from ``torch.stack`` if a field's per-piece outputs
            differ in shape, e.g. a smaller trailing piece that
            ``prepare_batch`` does not pad.
    """
    x, x_lens, ys, ys_lens = batch_data
    # Only the target tensors follow the layout flag; x/x_lens use dim 0.
    batch_dim = 0 if self.batch_first else 1
    # NOTE(review): when batch_first is False, ys_lens is split on dim 1, so it
    # must be at least 2-D; a 1-D lengths vector would raise here. Confirm
    # against callers.
    pieces = zip(
        x.chunk(chunks, 0),
        x_lens.chunk(chunks, 0),
        ys.chunk(chunks, batch_dim),
        ys_lens.chunk(chunks, batch_dim),
    )
    per_piece = [self.prepare_batch(piece, **kwargs) for piece in pieces]
    # Transpose piece-major -> field-major, then stack each field's pieces.
    # torch.stack is equivalent to unsqueeze(0) on each tensor + cat on dim 0.
    return [torch.stack(field) for field in zip(*per_piece)]
评论列表
文章目录