def updateOutput(self, input):
    """Forward pass: select the elements of a tensor where a mask is true.

    Args:
        input: A pair ``(data, mask)``. ``mask`` is passed to
            ``torch.masked_select`` together with ``data``. Presumably it is
            a boolean (or legacy byte) tensor broadcastable to ``data``;
            TODO confirm against the callers.

    Returns:
        ``self.output``, filled in place with the selected elements of
        ``data``. ``torch.masked_select`` produces a 1-D tensor.
    """
    data, mask = input
    # Use the documented ``out=`` keyword. The old result-first positional
    # form ``masked_select(out, input, mask)`` is rejected by current torch.
    torch.masked_select(data, mask, out=self.output)
    return self.output