def updateOutput(self, input):
    """Select entries of a tensor along ``self.dimension`` by index.

    Args:
        input: An indexable pair. ``input[0]`` is the tensor to select
            from and ``input[1]`` is the index tensor handed to
            ``torch.index_select``.

    Returns:
        ``self.output``, which ``torch.index_select`` fills in place via
        its ``out=`` argument.
    """
    source, index = input[0], input[1]
    # Write into the existing buffer (out=) rather than rebinding
    # self.output, so the returned object is the same tensor that was
    # already stored on the instance.
    torch.index_select(source, self.dimension, index, out=self.output)
    return self.output