def create_minibatch_indices(n, minibatch_size, shuffling=True):
    """Split the indices ``0 .. n-1`` into equally sized minibatches.

    When ``n`` is not a multiple of ``minibatch_size``, the last minibatch is
    padded with indices drawn uniformly at random (with replacement) from
    ``range(n)``, so every returned minibatch has exactly ``minibatch_size``
    entries.

    :param n: total number of indices to pick from (indices are ``0 .. n-1``);
        ``n == 0`` yields no minibatches
    :param minibatch_size: number of indices per minibatch; must be >= 1.
        If it exceeds ``n``, a single padded minibatch is returned.
    :param shuffling: if True, visit the indices in a random order; otherwise
        keep them in ascending order
    :return: (list of index arrays, each of length ``minibatch_size``,
        number of random duplicate indices appended to complete the last
        minibatch)
    :raises ValueError: if ``minibatch_size`` is smaller than 1
    """
    if minibatch_size < 1:
        raise ValueError(f"minibatch_size must be >= 1, got {minibatch_size!r}")
    if n == 0:
        # Nothing to batch; np.split would reject a request for 0 sections.
        return [], 0
    if shuffling:
        all_indices = np.random.permutation(n)  # shuffle order randomly
    else:
        all_indices = np.arange(n)
    n_steps = (n - 1) // minibatch_size + 1  # ceil(n / minibatch_size) batches per epoch
    n_rem = n_steps * minibatch_size - n  # padding needed to fill the last batch
    if n_rem > 0:
        # Draw the padding from the full index range [0, n). Drawing from
        # [0, n_rem) would only ever duplicate the first n_rem indices.
        inds_to_add = np.random.randint(0, n, size=n_rem)
        all_indices = np.concatenate((all_indices, inds_to_add))
    return np.split(all_indices, n_steps), n_rem
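# Leftover web-page navigation text from the original source page:
# 评论列表 ("comment list")
# 文章目录 ("table of contents")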