def sample(self, sample_shape=torch.Size()):
    """Draw samples from the normal distribution without autograd tracking.

    Args:
        sample_shape: Shape of samples to draw. It is passed to
            ``self._extended_shape`` to compute the full output shape.

    Returns:
        A tensor of shape ``self._extended_shape(sample_shape)``. It is drawn
        with ``torch.normal`` from ``self.mean`` and ``self.std``, each
        expanded (no copy) to that shape. The result does not require grad.
    """
    shape = self._extended_shape(sample_shape)
    # torch.normal is not a reparameterized (pathwise) sampler, so recording
    # it in the autograd graph yields no useful gradient for mean/std. It
    # would only keep the graph alive and mark the draw as requiring grad.
    # Sample under no_grad, as torch.distributions does for non-rsample draws.
    with torch.no_grad():
        return torch.normal(self.mean.expand(shape), self.std.expand(shape))