def compute_output_shape(self, input_shape):
    """Return the output shape derived from a pair of input shapes.

    Args:
        input_shape: A sequence holding exactly two shapes, each of which
            is itself an indexable sequence. Only the first entry of the
            first shape is read (presumably the batch dimension -- confirm
            against the caller).

    Returns:
        A 2-tuple ``(input_shape[0][0], 1)``: the leading entry of the
        first input shape followed by a fixed trailing size of 1.

    Raises:
        AssertionError: If ``input_shape`` does not contain exactly two
            entries.
    """
    # We get two inputs
    assert len(input_shape) == 2, (
        f"expected exactly 2 input shapes, got {len(input_shape)}"
    )
    leading_dim = input_shape[0][0]
    return (leading_dim, 1)