def form_set_data(labels, max_num, verbose=False):
    """Generate a same-class index set for every sample.

    Sample indices are first grouped by label. Then, for each sample, up to
    ``max_num`` indices are drawn without replacement (via ``rand.sample``)
    from the group of samples that share its label. The drawn set may or may
    not contain the sample's own index.

    Args:
        labels: Label container indexed as ``labels[idx]`` for
            ``idx in range(labels.size)``; presumably a 1-D numpy array of
            hashable labels (TODO confirm against callers).
        max_num: Maximum number of indices per set. If a class has fewer
            members, every member of that class is drawn (in random order).
        verbose: If true, print a progress line after each set is formed.

    Returns:
        dict mapping each sample index ``idx`` to a list of sample indices
        whose label equals ``labels[idx]``.
    """
    # NOTE(review): `rand` is a module-level name defined outside this view,
    # presumably `import random as rand` — confirm.
    num_samples = labels.size

    # Group sample indices by label.
    label_ids = {}
    for idx in range(num_samples):
        label_ids.setdefault(labels[idx], []).append(idx)

    set_ids = {}
    for idx in range(num_samples):
        # rand.sample does not mutate its population, so no defensive copy.
        samp_ids = label_ids[labels[idx]]
        samp_num = min(max_num, len(samp_ids))
        set_ids[idx] = rand.sample(samp_ids, samp_num)
        if verbose:
            # Function-call form parses under both Python 2 and Python 3.
            print("set {} formed.".format(idx))
    return set_ids
# Web-page residue: 评论列表 (Comment list)
# Web-page residue: 文章目录 (Table of contents)