def __iter__(self):
    """Yield one shuffled epoch of indices that covers every index at least once.

    Every index in ``range(len(self.weights))`` is included exactly once as a
    base pass. The remaining ``self.num_samples - len(self.weights)`` indices
    are drawn with replacement via ``torch.multinomial(self.weights, ...)``.
    The combined indices are shuffled with ``torch.randperm`` before being
    returned. All random draws happen eagerly, when ``__iter__`` is called.

    Returns:
        An iterator over ``self.num_samples`` Python ``int`` indices.

    Raises:
        ValueError: if ``self.num_samples`` is smaller than
            ``len(self.weights)``, since then not every index can be included.
    """
    num_items = len(self.weights)
    remaining = self.num_samples - num_items
    if remaining < 0:
        raise ValueError(
            f"num_samples ({self.num_samples}) must be >= len(weights) "
            f"({num_items}) so every index can be sampled at least once"
        )

    base_samples = torch.arange(num_items, dtype=torch.long)
    if remaining > 0:
        over_samples = torch.multinomial(self.weights, remaining, replacement=True)
        samples = torch.cat((base_samples, over_samples), dim=0)
    else:
        # torch.multinomial rejects a request for 0 samples; the base pass
        # already fills the epoch, so skip the oversampling step entirely.
        samples = base_samples

    # Shuffle with a single gather and convert to ints in one call, instead of
    # indexing the tensor element by element.
    return iter(samples[torch.randperm(len(samples))].tolist())
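# Scraped page residue (not code): "评论列表" (comment list)
# Scraped page residue (not code): "文章目录" (table of contents)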