def unbundle(state):
    """Split a bundled state into one slice per index along dimension 0.

    Args:
        state: ``None``, or a sequence of tensors that ``torch.cat`` can
            join along dimension 1 (so they must agree on every other
            dimension). Presumably the pieces of a batched recurrent
            state, e.g. an ``(h, c)`` pair -- confirm against callers.

    Returns:
        When ``state`` is ``None``: ``itertools.repeat(None)``, an iterator
        that yields ``None`` indefinitely.
        Otherwise: the tuple produced by ``torch.split``, holding one tensor
        per index of dimension 0, each of size 1 along that dimension and
        containing the inputs concatenated along dimension 1.
    """
    if state is None:
        # No state to split: hand back an endless supply of None placeholders.
        return itertools.repeat(None)

    # Join the pieces side by side (dim 1) before slicing, so each slice
    # along dim 0 carries every piece's values for that index.
    joined = torch.cat(state, 1)
    return torch.split(joined, 1, 0)