def _cluster(infn, outfn):
    """Mean-shift cluster the points of one input file and save the results.

    Parses ``infn`` with ``parse_file``, fits ``MeanShift`` (with
    ``bin_seeding=True``) on the returned coordinates, passes
    ``outfn + '.txt'`` to ``extract_battles`` and saves a 3-D scatter plot
    of the clusters to ``outfn + '.png'``.

    An ``outfn + '.lock'`` file marks the input as in progress so several
    workers can share one output location. The input is skipped when
    ``outfn + '.txt'`` already exists or when another worker holds the lock.
    The lock is always removed before returning, including on the invalid-input
    early return and when an exception propagates.

    Reads the module-level ``args``:
      * ``args.bandwidth`` - MeanShift bandwidth; a negative value means
        estimate it with ``estimate_bandwidth``.
      * ``args.min_deaths`` - clusters with fewer members than this are
        drawn in black.
      * ``args.show`` - if true, also display the figure after saving it.

    Args:
        infn: path of the input file handed to ``parse_file``.
        outfn: output path prefix (``.lock``, ``.txt`` and ``.png`` are appended).
    """
    import errno

    lock_path = outfn + '.lock'
    if path.exists(outfn + '.txt'):
        return
    # O_EXCL makes "check if the lock exists, else create it" one atomic step,
    # so two concurrent workers can never both claim the same input.
    try:
        fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
    except OSError as err:
        if err.errno == errno.EEXIST:
            return  # someone else is (or was) working on this input
        raise
    os.close(fd)

    try:
        print("doing " + infn)
        data, xyt, transform, untransform, valid, maxes = parse_file(infn)
        if not valid:
            return  # lock is still released by the outer finally

        bandwidth = args.bandwidth
        if bandwidth < 0:
            bandwidth = estimate_bandwidth(xyt, quantile=0.2, n_samples=500)
        ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
        ms.fit(xyt)

        # Map centers and points back out of the space clustering ran in.
        centers = untransform(ms.cluster_centers_)
        xyt = untransform(xyt)
        labels = ms.labels_
        # Per-label member counts; small clusters are drawn in black below.
        few = np.bincount(labels) < args.min_deaths
        extract_battles(outfn + '.txt', data, ms, maxes, xyt, transform, untransform)

        fig = plt.figure()
        try:
            ax = fig.add_subplot(111, projection='3d')
            colors = cycle('bgrcmykbgrcmykbgrcmykbgrcmyk')
            # bincount has one entry per label index 0..max(label), so this
            # visits every label even if some index has no members.
            for k, too_few in enumerate(few):
                my_members = labels == k
                col = 'black' if too_few else next(colors)
                center = centers[k]
                # marker must be a keyword: Axes3D.scatter's 4th positional
                # parameter is zdir, not the marker.
                ax.scatter(center[0], center[1], center[2], c=col, s=100, marker='o')
                ax.scatter(xyt[my_members, 0], xyt[my_members, 1], xyt[my_members, 2], c=col)
            # Save before show(): with interactive backends the figure is
            # destroyed when the window closes, which would yield a blank PNG.
            fig.savefig(outfn + ".png")
            if args.show:
                plt.show()
        finally:
            plt.close(fig)
    finally:
        try:
            os.remove(lock_path)
        except OSError:
            pass  # lock already gone; nothing left to clean up
评论列表
文章目录