def updateOutput(self, input):
    """Gather slices of ``input[0]`` along ``self.dimension`` at ``input[1]``.

    The result is written into ``self.output`` through ``index_select``'s
    ``out=`` argument. ``self.output`` is then returned, so the caller gets
    the same tensor object that the module holds.

    Args:
        input: a two-element sequence ``(tensor, index)``. ``tensor`` is the
            source to select from. ``index`` holds the positions to pick
            along ``self.dimension``. NOTE(review): PyTorch requires an
            integer (typically int64) 1-D index tensor; confirm callers
            always pass one.

    Returns:
        ``self.output``, holding the selected slices.
    """
    t = input[0]
    index = input[1]
    # The destination must be passed as the ``out=`` keyword. The old
    # result-first positional form, index_select(result, src, dim, index),
    # is not accepted by the current signature
    # index_select(input, dim, index, *, out=None).
    torch.index_select(t, self.dimension, index, out=self.output)
    return self.output