def get_output(self, input, **kwargs):
    """Return the result of 2D pooling applied to ``input``.

    The pooling configuration is taken from this object's attributes
    (``pool_size``, ``stride``, ``ignore_border``, ``pad`` and ``mode``)
    and forwarded as keyword arguments to ``pool_2d``.

    NOTE(review): ``pool_2d`` is presumably
    ``theano.tensor.signal.pool.pool_2d`` and ``input`` a symbolic 4D
    tensor; confirm against the file's imports.

    :param input: the tensor to pool; passed to ``pool_2d`` unchanged.
    :param kwargs: accepted but not used by this method.
    :returns: whatever ``pool_2d`` returns for this configuration.
    """
    # ``pool_2d`` names its window-size argument ``ws``.
    pool_options = {
        "ws": self.pool_size,
        "stride": self.stride,
        "ignore_border": self.ignore_border,
        "pad": self.pad,
        "mode": self.mode,
    }
    return pool_2d(input, **pool_options)