def get_balance_d(rng=None):
    """Build a class-balanced random subset of the module-level dataset.

    The positions of the module-level label sequence ``rel`` are visited in a
    random (shuffled) order. Every element labelled 0 or 1 is kept until that
    class has reached ``min(#label-0, #label-1)`` kept elements, so the result
    contains exactly the same number of 0s and 1s and no minority-class
    element is thrown away. Elements whose label is neither 0 nor 1 are
    always kept and do not count towards either class.

    Reads the module globals ``rel``, ``mat`` and ``turk_data_id`` and
    stores the results (all in the same shuffled order) in module globals:

    * ``bal_mat`` -- ``mat[indices]`` (assumes ``mat`` supports indexing by a
      list of positions, e.g. a NumPy array -- TODO confirm),
    * ``bal_rel`` -- list of the corresponding labels from ``rel``,
    * ``bal_turk_data_id`` -- list of the corresponding ``turk_data_id`` entries.

    Args:
        rng: Optional source of randomness exposing a ``shuffle`` method,
            e.g. ``np.random.default_rng(seed)`` or
            ``np.random.RandomState(seed)``; pass one to make the subset
            reproducible. ``None`` (default) uses the global NumPy random
            state via ``np.random.shuffle``.

    Returns:
        None. Results are delivered only through the globals listed above.
    """
    # NOTE: bal_turk_data / bal_turk_data_uncer were produced here by code
    # that has since been commented out; this function no longer sets them.
    global bal_mat, bal_rel, bal_turk_data_id

    order = np.arange(len(rel))
    # Use the caller's generator when given (reproducible runs); otherwise
    # fall back to NumPy's global random state.
    shuffle = np.random.shuffle if rng is None else rng.shuffle
    shuffle(order)

    # Each of the two classes may contribute at most as many elements as the
    # smaller class has in total, which yields an exact, maximal balance.
    target = min(
        sum(1 for label in rel if label == 0),
        sum(1 for label in rel if label == 1),
    )

    n0 = 0
    n1 = 0
    indices = []
    for i in order:
        label = rel[i]
        if label == 0:
            if n0 >= target:
                continue
            n0 += 1
        elif label == 1:
            if n1 >= target:
                continue
            n1 += 1
        # Labels other than 0/1 fall through and are always kept.
        indices.append(i)

    bal_mat = mat[indices]
    bal_rel = [rel[i] for i in indices]
    bal_turk_data_id = [turk_data_id[i] for i in indices]
评论列表
文章目录