def check_convergence(hdf5_file, iteration, convergence_iter, max_iter):
    """Decide whether affinity-propagation message-passing should stop.

    For the iteration that has just completed, records which samples are
    currently exemplars (positive diagonal of availabilities +
    responsibilities) into column ``iteration % convergence_iter`` of the
    ``parallel_updates`` array stored in the HDF5 file, then checks the
    termination criteria.

    Parameters
    ----------
    hdf5_file : str
        Path to the HDF5 file holding ``aff_prop_group.availabilities``,
        ``aff_prop_group.responsibilities`` and
        ``aff_prop_group.parallel_updates``.
    iteration : int
        Identifies the run of message-passing that has just completed.
    convergence_iter : int
        Number of consecutive iterations over which the exemplar set must
        remain unchanged for the procedure to be declared converged.
    max_iter : int
        Total budget of message-passing rounds.

    Returns
    -------
    bool
        True if ``iteration`` has reached ``max_iter``, or if the exemplar
        set has not changed for ``convergence_iter`` consecutive iterations
        and contains at least one exemplar; False otherwise.
    """
    # Serialise access to the shared HDF5 file. The lock must be released
    # even if opening/reading/writing the file raises; otherwise every other
    # worker waiting on it would block indefinitely.
    Worker.hdf5_lock.acquire()
    try:
        with tables.open_file(hdf5_file, 'r+') as fileh:
            group = fileh.root.aff_prop_group
            A = group.availabilities
            R = group.responsibilities
            P = group.parallel_updates

            N = A.nrows
            diag_ind = np.diag_indices(N)

            # Sample k is currently an exemplar when a(k, k) + r(k, k) > 0.
            E = (A[diag_ind] + R[diag_ind]) > 0
            # P keeps one column per iteration in a window of length
            # convergence_iter; overwrite the slot for this iteration.
            P[:, iteration % convergence_iter] = E

            e_mat = P[:]
            # Number of exemplars found in this iteration.
            K = E.sum(axis=0)
    finally:
        Worker.hdf5_lock.release()

    # Exhausting the iteration budget always ends the procedure, independent
    # of how convergence_iter compares to max_iter.
    if iteration >= max_iter:
        return True

    if iteration >= convergence_iter:
        se = e_mat.sum(axis=1)
        # Converged when every sample was an exemplar in all recorded
        # iterations of the window, or in none of them.
        converged = np.sum((se == convergence_iter) | (se == 0)) == N
        if converged and K > 0:
            return True

    return False
评论列表
文章目录