def sample(self, sample_shape=torch.Size()):
    """Draw a (non-differentiable) sample via ``torch.bernoulli``.

    Args:
        sample_shape: Extra leading shape of the draw. It is combined with the
            distribution's own shape by ``self._extended_shape``.

    Returns:
        A tensor of shape ``self._extended_shape(sample_shape)``. Each entry
        is 0 or 1, drawn with success probability taken from ``self.probs``
        (broadcast via ``expand``). The tensor is not attached to the
        autograd graph.
    """
    shape = self._extended_shape(sample_shape)
    # torch.bernoulli has a (zero) registered gradient. Without no_grad, a
    # sample drawn from a ``probs`` that requires grad would carry a grad_fn
    # and keep the probs graph alive. sample() is not reparameterized, so
    # the draw is made outside autograd.
    with torch.no_grad():
        return torch.bernoulli(self.probs.expand(shape))