def _feature_reorg(self, input, stride=2):
    """Move ``stride x stride`` spatial blocks of a feature map into channels.

    The ``H x W`` plane of each channel is cut into a ``stride x stride``
    grid of equally sized, non-overlapping blocks.  Each block becomes its
    own output channel, ordered by input channel, then block row, then block
    column::

        output[n, c * stride**2 + i * stride + j, y, x]
            == input[n, c, i * (H // stride) + y, j * (W // stride) + x]

    For ``stride=2`` this stacks the four quadrants (top-left, top-right,
    bottom-left, bottom-right) of every channel.  ``self`` is not used.

    Args:
        input: 4-D tensor of shape ``(N, C, H, W)``; may live on any device
            and need not be contiguous.
        stride: Number of blocks along each spatial axis (default 2).  Both
            ``H`` and ``W`` must be divisible by it.

    Returns:
        Contiguous tensor of shape
        ``(N, C * stride**2, H // stride, W // stride)`` with the same device
        and dtype as ``input``.

    Raises:
        ValueError: If ``stride`` is less than 1, or if ``H`` or ``W`` is not
            divisible by ``stride``.
    """
    if stride < 1:
        raise ValueError(f"stride must be a positive integer, got {stride}")
    N, C, H, W = input.size()
    if H % stride or W % stride:
        raise ValueError(
            f"H ({H}) and W ({W}) must both be divisible by stride ({stride})"
        )
    h_new, w_new = H // stride, W // stride
    # Split each spatial axis into (block index, offset inside block):
    # (N, C, H, W) -> (N, C, i, y, j, x).  reshape (unlike view) also copes
    # with non-contiguous inputs.  No index tensors are built, so nothing is
    # pinned to a particular device.
    blocks = input.reshape(N, C, stride, h_new, stride, w_new)
    # Bring the two block indices next to the channel axis: (N, C, i, j, y, x).
    blocks = blocks.permute(0, 1, 2, 4, 3, 5)
    # Fold (C, i, j) into the channel axis; reshape copies the permuted view,
    # so the result is contiguous.
    return blocks.reshape(N, C * stride * stride, h_new, w_new)
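# Comment list ("评论列表") -- navigation text left over from a scraped web page; not code.
# Table of contents ("文章目录") -- navigation text left over from a scraped web page; not code.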