def sample_data(data, num_sample):
    """Resample ``data`` along its first axis to exactly ``num_sample`` rows.

    ``data`` is an array of shape ``N x ...``.

    * If ``N == num_sample`` the data is returned unchanged.
    * If ``N > num_sample`` a random subset of ``num_sample`` distinct rows
      is kept.
    * If ``N < num_sample`` all original rows are kept (in order) and
      ``num_sample - N`` randomly chosen rows are duplicated and appended.

    Args:
        data: NumPy array whose first axis indexes the samples.
        num_sample: Number of rows the result must have.

    Returns:
        A ``(sampled, indices)`` tuple. ``sampled`` has ``num_sample`` rows
        and ``indices`` is a sequence of ``num_sample`` row indices into
        ``data`` such that ``sampled[i]`` equals ``data[indices[i]]``.

    Raises:
        ValueError: If ``data`` is empty while ``num_sample`` is positive
            (there is nothing to duplicate), or if NumPy rejects
            ``num_sample`` (e.g. a negative value).
    """
    N = data.shape[0]
    if N == num_sample:
        return data, range(N)

    if N == 0:
        # Nothing to draw from; fail with a clearer message than NumPy's.
        raise ValueError(
            "cannot draw %d samples from empty data" % num_sample
        )

    if N > num_sample:
        # replace=False so that we really *keep* num_sample distinct rows
        # instead of possibly picking the same row several times.
        sample = np.random.choice(N, num_sample, replace=False)
        return data[sample, ...], sample

    # N < num_sample: pad with randomly chosen duplicates. Sampling with
    # replacement is required here because num_sample - N may exceed N.
    sample = np.random.choice(N, num_sample - N)
    dup_data = data[sample, ...]
    # range objects cannot be concatenated with "+" on Python 3, so build
    # the index list explicitly.
    indices = list(range(N)) + sample.tolist()
    return np.concatenate([data, dup_data], 0), indices
评论列表
文章目录