def forward(self, x):
    """Encode ``x``, obtain a latent code from it, and decode that code.

    The encoder output is split into two equal halves along dim 1. The first
    half is used as the mean ``mu`` and the second as the log-variance
    ``log_var``. Both are passed to ``self.reparametrize`` to produce the
    latent ``z``, which is then fed to ``self.decoder``.

    Args:
        x: Input passed unchanged to ``self.encoder``.

    Returns:
        Tuple ``(out, mu, log_var)``: the decoder output, followed by the two
        halves of the encoder output.

    Raises:
        ValueError: If the encoder output's size along dim 1 is odd, so it
            cannot be split into equally sized ``mu`` and ``log_var``.
    """
    h = self.encoder(x)
    # torch.chunk does not guarantee equal pieces. An odd size yields unequal
    # halves that may broadcast silently downstream, and size 1 yields only
    # one piece. Fail loudly with a clear message instead.
    if h.size(1) % 2 != 0:
        raise ValueError(
            "encoder output must have an even size along dim 1 to split into "
            f"mu and log_var, got {h.size(1)}"
        )
    mu, log_var = torch.chunk(h, 2, dim=1)  # mean and log variance
    # NOTE(review): reparametrize is defined elsewhere; presumably it samples
    # z ~ N(mu, exp(log_var)) via the reparameterization trick. Confirm.
    z = self.reparametrize(mu, log_var)
    out = self.decoder(z)
    return out, mu, log_var