def pool2d(x, pool_size, strides=(1, 1), border_mode='valid',
           dim_ordering='th', pool_mode='max'):
    """Apply 2D max or average pooling to a 4D tensor via Theano's pool_2d.

    Args:
        x: 4D symbolic tensor. With ``dim_ordering='th'`` the pooled axes are
            axes 2 and 3. With ``'tf'`` the input is first permuted with
            ``dimshuffle((0, 3, 1, 2))``, so the pooled axes are the input's
            axes 1 and 2, and the result is permuted back afterwards.
            Presumably (batch, channels, rows, cols) for 'th' and
            (batch, rows, cols, channels) for 'tf' -- confirm against callers.
        pool_size: pair ``(rows, cols)`` giving the pooling window.
        strides: pair ``(row_stride, col_stride)``.
        border_mode: ``'valid'`` (no padding) or ``'same'`` (pad, then crop
            each pooled axis of the output to ``ceil(input_size / stride)``).
        dim_ordering: ``'th'`` or ``'tf'``; see ``x``.
        pool_mode: ``'max'`` or ``'avg'``. ``'avg'`` maps to Theano's
            ``'average_exc_pad'`` mode (padded cells excluded from the mean).

    Returns:
        The pooled tensor, in the same dim ordering as ``x``.

    Raises:
        ValueError: if ``border_mode``, ``dim_ordering`` or ``pool_mode`` is
            not one of the supported values.
    """
    if border_mode == 'same':
        # Theano pads both sides of each pooled axis by the same amount, so
        # over-pad here and crop the surplus output further down.
        row_pad = pool_size[0] - 2 if pool_size[0] % 2 == 1 else pool_size[0] - 1
        col_pad = pool_size[1] - 2 if pool_size[1] % 2 == 1 else pool_size[1] - 1
        # Clamp at 0: a window of size 1 (odd) would otherwise give -1.
        padding = (max(row_pad, 0), max(col_pad, 0))
    elif border_mode == 'valid':
        padding = (0, 0)
    else:
        raise ValueError('Invalid border mode: ' + str(border_mode))

    if dim_ordering not in {'th', 'tf'}:
        raise ValueError('Unknown dim_ordering ' + str(dim_ordering))

    # Keras pooling mode -> Theano pool_2d mode.
    theano_modes = {'max': 'max', 'avg': 'average_exc_pad'}
    if pool_mode not in theano_modes:
        raise ValueError('Invalid pooling mode: ' + str(pool_mode))

    if dim_ordering == 'tf':
        # Move the last axis in front of the two pooled axes.
        x = x.dimshuffle((0, 3, 1, 2))

    # ds/st/padding are Theano's older keyword names (later ws/stride/pad).
    pool_out = pool.pool_2d(x, ds=pool_size, st=strides,
                            ignore_border=True,
                            padding=padding,
                            mode=theano_modes[pool_mode])

    if border_mode == 'same':
        # Keep exactly ceil(size / stride) outputs along each pooled axis.
        out_rows = (x.shape[2] + strides[0] - 1) // strides[0]
        out_cols = (x.shape[3] + strides[1] - 1) // strides[1]
        pool_out = pool_out[:, :, :out_rows, :out_cols]

    if dim_ordering == 'tf':
        # Restore the caller's axis order.
        pool_out = pool_out.dimshuffle((0, 2, 3, 1))
    return pool_out
# 评论列表 (Comment list)
# 文章目录 (Table of contents)