def gather_index(input: "torch.Tensor", index: "torch.Tensor") -> "torch.Tensor":
    """Pick one element per row of a 2-D tensor.

    For every row ``i`` the result holds ``input[i, index[i]]``.

    Args:
        input: 2-D tensor of shape ``(N, C)``.
        index: 1-D tensor of column positions, one per row. A length-1
            ``index`` is broadcast to every row. ``torch.gather`` requires
            an integer (int64) index tensor.

    Returns:
        1-D tensor of shape ``(N,)`` with the selected elements.

    Raises:
        AssertionError: if ``input`` is not 2-D or ``index`` is not 1-D.
        RuntimeError: if ``index`` cannot be expanded to ``N`` rows or holds
            an out-of-range column (raised by ``expand`` / ``torch.gather``).
    """
    assert input.dim() == 2 and index.dim() == 1, (
        "gather_index expects a 2-D input and a 1-D index"
    )
    # Only a single column is needed in the result, so build an (N, 1) index
    # rather than expanding it to the full (N, C) shape of `input` and
    # discarding all but the first gathered column.
    column_index = index.unsqueeze(1).expand(input.size(0), 1)
    return torch.gather(input, 1, column_index).squeeze(1)