def get_epoch_indexes(self):
    """Build the (anchor, positive) index batches for one epoch.

    Index ``k * num_per_class + j`` is treated as the ``j``-th example of
    class ``k`` (a class-major layout; presumably this matches how the
    dataset is stored -- confirm against the caller).

    For every within-class position ``m`` the classes are shuffled and split
    into ``num_classes // batch_size`` batches with ``np.array_split``; when
    ``num_classes`` is not a multiple of ``batch_size`` the remainder is
    spread over the batches, so a batch may hold more than ``batch_size``
    classes. Each class in a batch contributes its ``m``-th example as the
    anchor and a uniformly chosen *other* example of the same class as the
    positive. Randomness comes from the global ``np.random`` state.

    Returns:
        list of ``(indexes_anchor, indexes_positive)`` tuples of 1-D integer
        arrays of equal length; ``num_per_class * (num_classes // batch_size)``
        batches in total, grouped by ``m`` in increasing order. An empty list
        when ``num_per_class == 0``.

    Raises:
        ValueError: if ``batch_size < 1``, if ``num_classes < batch_size``
            (not even one batch can be formed), or if ``num_per_class == 1``
            (no positive distinct from the anchor exists).
    """
    B = self.batch_size
    K = self.num_classes
    M = self.num_per_class
    if M == 0:
        return []  # no examples per class: nothing to batch
    if B < 1:
        raise ValueError(f"batch_size must be >= 1, got {B}")
    if K < B:
        raise ValueError(
            f"num_classes ({K}) must be >= batch_size ({B}) to form a batch")
    if M == 1:
        raise ValueError(
            "num_per_class must be >= 2 so that a positive distinct from "
            "the anchor exists")
    batches_per_pass = int(K // B)  # batches produced for each position m
    # indexes[k, j] is the dataset index of the j-th example of class k.
    indexes = np.arange(K * M, dtype=np.int32).reshape(K, M)
    epoch_indexes = []
    for m in range(M):
        perm = np.random.permutation(K)
        for c_batch in np.array_split(perm, batches_per_pass):
            b = len(c_batch)  # actual batch size; may exceed B (see above)
            indexes_anchor = M * c_batch + m
            # Drop column m so the positive can never be the anchor itself.
            positive_candidates = np.delete(indexes[c_batch], m, axis=1)
            indexes_positive = positive_candidates[
                np.arange(b), np.random.choice(M - 1, size=b)]
            epoch_indexes.append((indexes_anchor, indexes_positive))
    return epoch_indexes
# Residue from the source web page (not code): "评论列表" (comment list),
# "文章目录" (table of contents).