def get_output_for(self, input, **kwargs):
    """Pool ``input`` in 2D, with an extra ``'strictsamex'`` padding mode.

    When ``self.pad == 'strictsamex'`` the input is padded by hand at the
    start of axis 2 only, pooled without any further padding, and the
    result is cropped along axis 2 to at most the pre-padding length.
    Any other ``self.pad`` value is forwarded unchanged to
    ``pool.pool_2d``.

    Parameters
    ----------
    input : tensor expression
        The expression to pool. In ``'strictsamex'`` mode the result is
        indexed with four axes, so a 4-axis input is expected.
    **kwargs
        Accepted for interface compatibility; not used here.

    Returns
    -------
    tensor expression
        The pooled (and, in ``'strictsamex'`` mode, cropped) expression.

    Raises
    ------
    AssertionError
        If ``'strictsamex'`` is requested while ``self.stride[0] != 1``.

    Notes
    -----
    ``'strictsamex'`` mode sets ``self.ignore_border = True`` on the layer
    and leaves it set after the call returns.
    """
    strict_same = self.pad == 'strictsamex'

    if strict_same:
        # Raised explicitly rather than via ``assert`` so the check is not
        # stripped when Python runs with -O.
        if self.stride[0] != 1:
            raise AssertionError(
                "pad='strictsamex' requires stride[0] == 1, got %r"
                % (self.stride[0],))

        # NOTE(review): assuming pool_2d produces ``n - k + 1`` outputs at
        # stride 1 with ignore_border=True, left-padding by ceil(k / 2)
        # restores the full length only for k <= 3; larger pools come out
        # shorter than the input. Confirm this is intended.
        left_pad = int(np.ceil(self.pool_size[0] / 2.))

        # Remember the pre-padding extent of axis 2 for the final crop.
        length = input.shape[2]

        # Persistent side effect: the layer keeps ignore_border=True.
        self.ignore_border = True

        # Pad only the start of axis 2; batch_ndim=2 leaves axes 0 and 1
        # untouched.
        input = padding.pad(input, [(left_pad, 0)], batch_ndim=2)

        # Padding was applied by hand, so pool_2d must not add its own.
        pad = (0, 0)
    else:
        pad = self.pad

    pooled = pool.pool_2d(input,
                          ds=self.pool_size,
                          st=self.stride,
                          ignore_border=self.ignore_border,
                          padding=pad,
                          mode=self.mode,
                          )

    if strict_same:
        # ``or None`` turns a falsy length into an open-ended slice
        # instead of an empty one.
        pooled = pooled[:, :, :length or None, :]

    return pooled
# Adds a 'strictsamex' option for the `pad` setting of the pooling layer.
# (web-page navigation text, not code) Comment list
# (web-page navigation text, not code) Table of contents