def output_func(self, input):
    """Pool ``input`` with a window of ``self.pool_size``.

    Args:
        input: symbolic tensor to pool. Per the original author's note it is
            shaped (batch_size, nwords, ndim). TODO confirm against callers.

    Returns:
        Whatever ``pool.pool_2d`` returns for this input. ``pool`` is
        presumably ``theano.tensor.signal.pool`` (verify against the
        file's imports), which pools over the last two dimensions.
    """
    # In input we get a tensor (batch_size, nwords, ndim)
    # ignore_border=True: partial windows at the edges are dropped rather
    # than pooled (Theano semantics, assuming `pool` is Theano's module).
    # NOTE(review): `ds` is Theano's older name for the pooling window;
    # newer releases call it `ws` and warn on `ds`. It is kept here for
    # compatibility with older Theano versions.
    return pool.pool_2d(input=input, ds=self.pool_size, ignore_border=True)