def forward(self, input):
    """Pool each channel's spatial map from its top-k and bottom-k activations.

    For every (batch, channel) pair the ``h * w`` spatial values are sorted in
    descending order. The result is the mean of the ``kmax`` largest values.
    When ``kmin > 0`` and ``self.alpha != 0`` it becomes
    ``(mean_of_kmax_largest + alpha * mean_of_kmin_smallest) / 2``.

    Args:
        input: 4-D tensor of shape ``(batch, channels, h, w)``.

    Returns:
        Tensor of shape ``(batch, channels)``.

    Side effects:
        Sets ``self.indices_max`` (and ``self.indices_min`` only when the
        min term is used) and saves ``input`` via ``self.save_for_backward``.
    """
    batch_size = input.size(0)
    num_channels = input.size(1)
    h = input.size(2)
    w = input.size(3)
    n = h * w  # number of spatial regions per channel

    # NOTE(review): get_positive_k is defined outside this view; presumably it
    # turns a (possibly fractional) k into a region count in [0, n] -- confirm.
    kmax = self.get_positive_k(self.kmax, n)
    kmin = self.get_positive_k(self.kmin, n)

    # .contiguous() lets non-contiguous inputs (e.g. permuted views) be
    # flattened; a bare .view() raises on them.
    flat = input.contiguous().view(batch_size, num_channels, n)
    sorted_vals, indices = torch.sort(flat, dim=2, descending=True)

    self.indices_max = indices.narrow(2, 0, kmax)
    output = sorted_vals.narrow(2, 0, kmax).sum(2).div_(kmax)

    # Compare by value: the former `alpha is not 0` tested identity, so a float
    # alpha of 0.0 wrongly entered this branch and halved the output.
    if kmin > 0 and self.alpha != 0:
        self.indices_min = indices.narrow(2, n - kmin, kmin)
        min_term = sorted_vals.narrow(2, n - kmin, kmin).sum(2).mul_(self.alpha / kmin)
        output.add_(min_term).div_(2)

    self.save_for_backward(input)
    return output.view(batch_size, num_channels)
评论列表
文章目录