def th_gather_2d(input, coords):
    """Gather one element of ``input`` per (row, col) pair in ``coords``.

    Each pair is turned into a single linear offset
    ``row * input.size(1) + col``. That offset is then looked up in
    ``th_flatten(input)``.

    Args:
        input: tensor with at least two dimensions; ``size(1)`` is used as
            the row stride. NOTE(review): the offset arithmetic only matches
            row-major element order if ``th_flatten`` flattens row-major and
            ``input`` is effectively 2-D -- confirm against ``th_flatten``.
        coords: 2-D tensor whose column 0 holds row indices and column 1
            holds column indices. ``index_select`` requires the resulting
            offsets to be an integer (int32/int64) tensor.

    Returns:
        A tensor viewed to shape ``(coords.size(0),)``, holding one gathered
        value per coordinate pair.
    """
    rows = coords[:, 0]
    cols = coords[:, 1]
    row_stride = input.size(1)
    # Row-major linear offset of each (row, col) pair.
    flat_offsets = rows * row_stride + cols

    flat_input = th_flatten(input)
    picked = flat_input.index_select(0, flat_offsets)
    return picked.view(coords.size(0))