def output_strong(tcs, alpha, mis, labels, prefix='', topk=5):
    """Write per-group NTC scores, highest first, to a text report.

    For each group ``j`` (one per row of ``alpha``) the score is::

        (sum of the ``topk`` largest entries of alpha[j] * mis[j] - ixy[j])
            / ((topk - 1) * hys[j])

    where ``ixy[j] = max(0, sum(alpha[j] * mis[j]) - tcs[j])`` and
    ``hys[j]`` is ``entropy(labels[:, j])`` floored at 1e-6.

    The report goes to ``prefix + '/text_files/most_deterministic_groups.txt'``
    (opened through ``safe_open`` with mode ``'w+'``). It contains a
    ``'Group num., NTC'`` header line followed by one ``'<group>, <score>'``
    line per group, with the score formatted to 3 decimals and groups sorted
    by descending score.

    Args:
        tcs: Per-group values subtracted from the weighted MI row sums
            (presumably total correlations -- confirm against caller).
        alpha: 2-D array; its number of rows ``m`` is the number of groups.
        mis: Array multiplied element-wise with ``alpha``; assumed to have
            the same (m, n) shape -- confirm against caller.
        labels: 2-D array; column ``j`` is passed to ``entropy`` for group ``j``.
        prefix: Directory prefix prepended to the output path.
        topk: Number of largest per-group terms to sum. Must be at least 2
            because the normaliser is ``topk - 1``. Defaults to 5, the value
            that was previously hard-coded.

    Raises:
        ValueError: If ``topk < 2``, or if ``alpha.shape`` does not unpack
            into exactly two dimensions.
    """
    if topk < 2:
        raise ValueError('topk must be at least 2, got %r' % (topk,))

    m, _ = alpha.shape
    # Weighted MI row sums minus tcs, clipped so the result is never negative.
    ixy = np.clip(np.sum(alpha * mis, axis=1) - tcs, 0, np.inf)
    # Floor the entropies at 1e-6 so a zero entropy cannot cause a division by zero below.
    hys = np.array([entropy(labels[:, j]) for j in range(m)]).clip(1e-6)
    # NOTE(review): if alpha has fewer than ``topk`` columns, the slice sums
    # fewer terms but the denominator still uses ``topk - 1`` -- confirm intended.
    ntcs = [
        (np.sum(np.sort(alpha[j] * mis[j])[-topk:]) - ixy[j]) / ((topk - 1) * hys[j])
        for j in range(m)
    ]

    # Compute every score before opening the file ('w+' presumably truncates
    # it), so a failure above does not leave an empty report behind. The
    # try/finally makes sure the handle is closed even if a write raises.
    f = safe_open(prefix + '/text_files/most_deterministic_groups.txt', 'w+')
    try:
        f.write('Group num., NTC\n')
        # Sort by descending score; ties keep ascending group order (stable sort).
        for j, ntc in sorted(enumerate(ntcs), key=lambda q: -q[1]):
            f.write('%d, %0.3f\n' % (j, ntc))
    finally:
        f.close()
评论列表
文章目录