def __init__(self, dataset, reweighting, model, large_batch=1024,
             forward_batch_size=128, steps_per_epoch=300, recompute=2,
             s_e=(1, 1), n_epochs=1):
    """Store the sampler configuration and create its per-sample state.

    ``dataset``, ``reweighting``, ``model``, ``large_batch`` and
    ``forward_batch_size`` are passed unchanged to the parent
    initializer. ``steps_per_epoch``, ``recompute``, ``s_e`` and
    ``n_epochs`` are kept on the instance exactly as given; how they are
    interpreted is up to the methods that read them (not visible here).

    NOTE(review): ``s_e`` defaults to a 2-tuple, presumably a
    (start, end) pair -- confirm against the code that consumes it.
    """
    # Delegate the common arguments to the parent initializer
    super(OnlineBatchSelectionSampler, self).__init__(
        dataset, reweighting, model,
        large_batch=large_batch,
        forward_batch_size=forward_batch_size
    )

    # Settings specific to OnlineBatchSelection, stored verbatim
    self.steps_per_epoch = steps_per_epoch
    self.recompute = recompute
    self.s_e = s_e
    self.n_epochs = n_epochs

    # Progress counters, both starting from zero
    self._batch = self._epoch = 0

    # Per-sample state: one entry for every element of the training data
    n_samples = len(dataset.train_data)
    initial_scores = np.ones((n_samples,))
    self._raw_scores = initial_scores
    # A separate array (not an alias) with the same shape, also all ones
    self._scores = np.ones_like(initial_scores)
    # Starts out as the identity ordering 0 .. n_samples - 1
    self._ranks = np.arange(n_samples)
评论列表
文章目录