def updateOutput(self, input):
    """Gather the entries of a tensor selected by a mask.

    Args:
        input: a two-element sequence ``(tensor, mask)``. Both items are
            passed straight to ``torch.masked_select``.

    Returns:
        ``self.output``, which ``torch.masked_select`` fills in place
        (via ``out=``) with the selected elements.
    """
    # Unpack into distinct names instead of rebinding the ``input`` parameter.
    source, selector = input
    # ``out=`` writes the result into the existing ``self.output`` tensor.
    torch.masked_select(source, selector, out=self.output)
    return self.output