def __call__(self, x):
    """Return a Gaussian distribution whose mean depends on ``x``.

    The mean is produced by ``self.hidden_layers``. The variance does not
    depend on ``x``: the learned ``self.var_param`` is passed through
    softplus and then broadcast to the shape of the mean.

    Args:
        x: Input forwarded unchanged to ``self.hidden_layers``.

    Returns:
        ``distribution.GaussianDistribution`` built from the computed mean
        and the broadcast variance.
    """
    loc = self.hidden_layers(x)
    # softplus maps the unconstrained parameter to positive values, as a
    # variance requires; NOTE(review): var_param's shape is assumed to be
    # broadcast-compatible with loc (e.g. the action dimension) — confirm.
    positive_var = F.softplus(self.var_param)
    return distribution.GaussianDistribution(
        loc, F.broadcast_to(positive_var, loc.shape))