def adaptive_avgmax_pool2d(x, pool_type='avg', padding=0, count_include_pad=False):
    """Pool ``x`` with a kernel sized to its full dim-2 x dim-3 extent.

    The kernel size is read from the input at call time
    (``(x.size(2), x.size(3))``), so the function adapts to whatever size
    the input has in those two dimensions. With the default ``padding=0``
    a single pooling window covers the whole extent of both dimensions.

    Args:
        x: Input tensor. It must have at least four dimensions because
            dims 2 and 3 are queried for the kernel size (presumably an
            ``(N, C, H, W)`` feature map -- confirm against callers).
        pool_type: Pooling mode selector:
            ``'avg'``     -- average pooling (default);
            ``'max'``     -- max pooling;
            ``'avgmax'``  -- element-wise mean of the average-pooled and
                             max-pooled results (same shape as either);
            ``'avgmaxc'`` -- average-pooled and max-pooled results
                             concatenated along dim 1 (doubling that dim).
            Any other value prints a notice and falls back to average
            pooling.
        padding: Forwarded unchanged to ``F.avg_pool2d`` / ``F.max_pool2d``.
        count_include_pad: Forwarded unchanged to ``F.avg_pool2d`` only;
            max pooling does not take this argument.

    Returns:
        The pooled tensor.
    """
    # One kernel covering all of dims 2 and 3; computed once instead of
    # being rebuilt for every pooling call.
    kernel_size = (x.size(2), x.size(3))

    def _avg_pool():
        # Average pooling over the full window.
        return F.avg_pool2d(
            x, kernel_size=kernel_size, padding=padding, count_include_pad=count_include_pad)

    def _max_pool():
        # Max pooling over the full window.
        return F.max_pool2d(x, kernel_size=kernel_size, padding=padding)

    if pool_type == 'avgmaxc':
        # Keep both statistics side by side along dim 1 (avg first, then max).
        return torch.cat([_avg_pool(), _max_pool()], dim=1)
    if pool_type == 'avgmax':
        # Blend the two statistics with equal weight.
        return 0.5 * (_avg_pool() + _max_pool())
    if pool_type == 'max':
        return _max_pool()

    # 'avg' and every unrecognised value end up here; unknown values are
    # reported but deliberately tolerated rather than raised.
    if pool_type != 'avg':
        print('Invalid pool type %s specified. Defaulting to average pooling.' % pool_type)
    return _avg_pool()
adaptive_avgmax_pool.py 文件源码
python
阅读 25
收藏 0
点赞 0
评论 0
评论列表
文章目录