def solve_closest_pair_n_logn2(points):
    """Find the closest pair of points with divide and conquer.

    The candidate strip around each split line is re-sorted by y at every
    recursion level, giving the intended O(n log^2 n) behaviour.

    NOTE(review): assumes ``euclidean_dis_pow(a, b)`` returns the *squared*
    Euclidean distance, as its name suggests -- confirm against its
    definition. The strip filters below compare squared coordinate gaps
    against ``d`` on that assumption.

    Args:
        points: list of 2-D points, each indexable as ``p[0]`` (x) and
            ``p[1]`` (y). The list is sorted in place by (x, y) as a side
            effect.

    Returns:
        Tuple ``(d, p1, p2)`` where ``d == euclidean_dis_pow(p1, p2)`` for
        the closest pair found. For a single point, ``d`` is ``float("inf")``
        and ``p1`` and ``p2`` are both that point.

    Raises:
        ValueError: if ``points`` is empty.
    """
    if not points:
        # Without this guard the recursion below never terminates (R == -1).
        raise ValueError("closest pair requires at least one point")

    def brute_force(lo, hi):
        # Exhaustive search over points[lo..hi]; used for ranges of <= 3 points
        # so the recursion never produces a one-point range internally.
        best_d, best_a, best_b = float("inf"), points[lo], points[hi]
        for i in range(lo, hi + 1):
            for j in range(i + 1, hi + 1):
                t = euclidean_dis_pow(points[i], points[j])
                if t < best_d:
                    best_d, best_a, best_b = t, points[i], points[j]
        return best_d, best_a, best_b

    def closest_pair(lo, hi):
        # Closest pair within points[lo..hi] (inclusive), already sorted by x.
        if hi - lo < 3:
            return brute_force(lo, hi)
        mid = (lo + hi) >> 1
        d, p1, p2 = closest_pair(lo, mid)
        d2, p3, p4 = closest_pair(mid + 1, hi)
        if d > d2:
            d, p1, p2 = d2, p3, p4

        mid_x = points[mid][0]
        # d is a squared distance, so compare the squared x-gap against it
        # rather than using d itself as a coordinate width.
        strip = [p for p in points[lo:hi + 1] if (p[0] - mid_x) ** 2 <= d]
        strip.sort(key=lambda p: p[1])

        n = len(strip)
        for i in range(n):
            a = strip[i]
            for j in range(i + 1, n):
                b = strip[j]
                dy = b[1] - a[1]
                # Strip is sorted by y: once the squared y-gap exceeds d,
                # no later point can be closer to a.
                if dy * dy > d:
                    break
                t = euclidean_dis_pow(a, b)
                if t < d:
                    d, p1, p2 = t, a, b
        return d, p1, p2

    points.sort(key=lambda p: (p[0], p[1]))
    return closest_pair(0, len(points) - 1)
评论列表
文章目录