def neg_samples(self, batch_size: int):
    """Draw negative-sample indices from ``self.output_dist``.

    The total number of draws is ``batch_size * self.n_neg_samples``. Draws
    are made with replacement, so the same index can appear more than once.

    NOTE(review): ``multinomial`` is presumably ``torch.multinomial`` and
    ``self.output_dist`` a tensor of non-negative (unnormalized) weights;
    confirm both against the module's imports and constructor.

    Args:
        batch_size: Number of batch items that need negatives. It is
            multiplied by ``self.n_neg_samples`` to get the draw count.

    Returns:
        Whatever ``multinomial`` returns for ``self.output_dist``. For a 1-D
        weight vector under torch, this is a flat index tensor of length
        ``batch_size * self.n_neg_samples``. It is not reshaped per batch item.
    """
    total_draws = batch_size * self.n_neg_samples
    # Replacement is required because the draw count may exceed the number
    # of categories in output_dist.
    drawn = multinomial(
        self.output_dist,
        num_samples=total_draws,
        replacement=True,
    )
    return drawn