def split_ps_reuse(point_set, level, pos, tree, cutdim):
    """Split a point set at the median of its widest axis and store both halves.

    The split axis is the coordinate with the largest extent (max - min).
    Rows whose coordinate on that axis is strictly greater than the (lower)
    median go to child ``2 * pos``; rows strictly less go to child
    ``2 * pos + 1``. If either side has fewer than ``N // 2`` rows, it is
    padded by repeating one row that lies exactly on the median.

    Args:
        point_set: 2-D tensor with one point per row, shape ``(N, D)``.
        level: level of ``point_set`` in ``tree``; children are written to
            ``tree[level + 1]``.
        pos: index of ``point_set`` within its level; children are written
            at ``2 * pos`` and ``2 * pos + 1``.
        tree: indexable container; ``tree[level + 1]`` must support item
            assignment at the two child positions.
        cutdim: indexable container; the chosen split axis (a Python ``int``)
            is written to ``cutdim[level][2 * pos]`` and
            ``cutdim[level][2 * pos + 1]``.

    Returns:
        None. Results are written into ``tree`` and ``cutdim`` in place.
    """
    # Integer division: on Python 3 ``/`` yields a float, which ``repeat`` rejects.
    num_points = point_set.size(0) // 2

    # Cut along the axis with the largest spread.
    extent = point_set.max(dim=0)[0] - point_set.min(dim=0)[0]
    dim = int(torch.argmax(extent))

    column = point_set[:, dim]
    # Lower median; it is always an element of ``column``, so ``middle_idx``
    # below is non-empty for finite inputs.
    cut = torch.median(column)

    # Use reshape(-1) rather than squeeze(): squeeze() turns a single match
    # into a 0-d tensor, which breaks both torch.cat and [0:1] slicing.
    left_idx = torch.nonzero(column > cut).reshape(-1)
    right_idx = torch.nonzero(column < cut).reshape(-1)
    middle_idx = torch.nonzero(column == cut).reshape(-1)

    # A single on-median row is used to pad whichever side is short.
    pad = middle_idx[0:1]
    if left_idx.numel() < num_points:
        left_idx = torch.cat([left_idx, pad.repeat(num_points - left_idx.numel())], 0)
    if right_idx.numel() < num_points:
        right_idx = torch.cat([right_idx, pad.repeat(num_points - right_idx.numel())], 0)

    left_ps = torch.index_select(point_set, dim=0, index=left_idx)
    right_ps = torch.index_select(point_set, dim=0, index=right_idx)

    tree[level + 1][pos * 2] = left_ps
    tree[level + 1][pos * 2 + 1] = right_ps
    cutdim[level][pos * 2] = dim
    cutdim[level][pos * 2 + 1] = dim
    return
评论列表
文章目录