def index(batch_size, x):
    """Prepend a row-index column to ``x``.

    Builds the column ``[[0], [1], ..., [batch_size - 1]]`` and concatenates
    it in front of ``x`` along dimension 1. For each row ``i``, the result is
    ``[i, *x[i]]``.

    Args:
        batch_size: Number of rows. It must equal ``x.size(0)``, because
            ``torch.cat`` requires matching sizes on the non-concatenated
            dimension.
        x: 2-D tensor with ``batch_size`` rows. ``torch.cat`` requires both
            inputs to have the same number of dimensions. NOTE(review): this
            presumably holds integer indices (e.g. to form (row, col) index
            pairs); confirm with callers.

    Returns:
        Tensor of shape ``(batch_size, x.size(1) + 1)`` on ``x``'s device.
        Its dtype follows ``torch.cat``'s type promotion of int64 with
        ``x.dtype``.

    Raises:
        RuntimeError: From ``torch.cat`` if ``x`` is not 2-D or its first
            dimension differs from ``batch_size``.
    """
    # Allocate the index column on x's device. An index column left on the CPU
    # cannot be concatenated with a CUDA/accelerator tensor.
    idx = torch.arange(batch_size, dtype=torch.long, device=x.device)
    # Shape (batch_size,) -> (batch_size, 1), so it stacks as a leading column.
    return torch.cat((idx.unsqueeze(-1), x), dim=1)