def convert(batch, device):
def to_device_batch(batch):
if device is None:
return batch
elif device < 0:
return [chainer.dataset.to_device(device, x) for x in batch]
else:
xp = cuda.cupy.get_array_module(*batch)
concat = xp.concatenate(batch, axis=0)
sections = np.cumsum([len(x) for x in batch[:-1]], dtype='i')
concat_dev = chainer.dataset.to_device(device, concat)
batch_dev = cuda.cupy.split(concat_dev, sections)
return batch_dev
return {'xs': to_device_batch([x for x, _ in batch]),
'ys': to_device_batch([y for _, y in batch])}
评论列表
文章目录