def iter_combinatorial_pairs(queue, num_examples, batch_size, interval,
                             num_classes, augment_positive=False):
    """Yield shuffled mini-batches of example pairs with same-class targets.

    Every unordered index pair ``(i, j)`` with ``i < j`` over
    ``range(num_examples)`` is enumerated once, the pairs are shuffled with
    ``np.random.permutation`` and split with ``np.array_split`` into
    ``max(1, num_pairs // batch_size)`` nearly equal batches, so each batch
    holds at least ``batch_size`` pairs unless there are fewer pairs than
    that in total (then a single smaller batch is produced).

    Args:
        queue: Object whose ``get()`` returns ``(x, c)``. ``x`` is indexed by
            example along its first axis; ``c`` is flattened with ``ravel()``
            and indexed the same way, one label per example. The division by
            255 suggests ``x`` holds 8-bit pixel values -- assumption,
            confirm with the producer.
        num_examples: Upper bound (exclusive) of the pair indices; presumably
            the number of rows in each ``x`` taken from the queue.
        batch_size: Target number of pairs per batch.
        interval: A fresh ``(x, c)`` is taken from ``queue`` on the first
            batch and then every ``interval`` batches; the batches in between
            reuse the most recently fetched data.
        num_classes: Used to derive ``num_examples // num_classes`` and the
            arguments for ``make_positive_pairs`` when ``augment_positive``.
        augment_positive: If True, append the index pairs returned by
            ``make_positive_pairs`` (defined elsewhere; presumably extra
            same-class pairs, per its name -- confirm).

    Yields:
        ``(x0, x1, t)``: ``x0``/``x1`` are the rows of ``x`` (cast to float32
        and divided by 255) for the first/second member of each pair, and
        ``t`` is an int32 array that is 1 where the two labels are equal and
        0 otherwise. Nothing is yielded when there are no pairs.
    """
    num_examples_per_class = num_examples // num_classes
    # Same row-major (i < j) order as itertools.combinations(range(n), 2),
    # but built in NumPy and always shaped (num_pairs, 2) -- even for n < 2.
    rows, cols = np.triu_indices(num_examples, k=1)
    pairs = np.column_stack((rows, cols))
    if augment_positive:
        additional_positive_pairs = make_positive_pairs(
            num_classes, num_examples_per_class, num_classes - 1)
        pairs = np.concatenate((pairs, additional_positive_pairs))
    num_pairs = len(pairs)
    if num_pairs == 0:
        return
    # np.array_split rejects 0 sections, which previously crashed whenever
    # there were fewer pairs than batch_size; fall back to a single batch.
    num_batches = max(1, num_pairs // batch_size)
    perm = np.random.permutation(num_pairs)
    for i, batch_indexes in enumerate(np.array_split(perm, num_batches)):
        # i == 0 always fetches, so x and c are bound before first use.
        if i % interval == 0:
            x, c = queue.get()
            x = x.astype(np.float32) / 255.0
            c = c.ravel()
        indexes0, indexes1 = pairs[batch_indexes].T
        x0, x1 = x[indexes0], x[indexes1]
        # 1 if x0 and x1 are the same class, 0 otherwise.
        t = (c[indexes0] == c[indexes1]).astype(np.int32)
        yield x0, x1, t
评论列表
文章目录