def output_func(self, input):
    """Apply 2D max pooling to ``input`` using this object's ``pool_size``.

    Args:
        input: Per the original author's note, a tensor of shape
            (batch_size, nwords, ndim).

    Returns:
        The result of ``downsample.max_pool_2d`` called with
        ``ds=self.pool_size`` and ``ignore_border=True``.
    """
    # NOTE(review): `downsample` is presumably Theano's
    # `theano.tensor.signal.downsample` module -- confirm against the imports.
    pooled = downsample.max_pool_2d(
        input=input,
        ds=self.pool_size,
        ignore_border=True,
    )
    return pooled