def forward(self, input, compute_loss=False, avg_loss=True):
    """Encode ``input``, sample a latent vector, and decode it.

    Args:
        input: Tensor passed to ``self.en1_fc``. Presumably a
            ``(batch, vocabulary)`` matrix -- TODO confirm against the
            caller; the softmax axis below relies on features being dim 1.
        compute_loss: When true, also evaluate ``self.loss`` and return it
            alongside the reconstruction.
        avg_loss: Forwarded unchanged to ``self.loss``.

    Returns:
        ``recon`` when ``compute_loss`` is false, otherwise the tuple
        ``(recon, loss)`` where ``loss`` is whatever ``self.loss`` returns.
    """
    # compute posterior
    en1 = F.softplus(self.en1_fc(input))  # first encoder layer output
    en2 = F.softplus(self.en2_fc(en1))  # second encoder layer output
    en2 = self.en2_drop(en2)
    posterior_mean = self.mean_bn(self.mean_fc(en2))  # posterior mean
    posterior_logvar = self.logvar_bn(self.logvar_fc(en2))  # posterior log variance
    posterior_var = posterior_logvar.exp()

    # take sample
    # Noise is allocated from ``input`` so it shares its tensor type, and is
    # shaped like the posterior mean.
    eps = Variable(input.data.new().resize_as_(posterior_mean.data).normal_())
    # Reparameterization: mean + std * noise, with std = sqrt(var).
    z = posterior_mean + posterior_var.sqrt() * eps
    # ``dim`` is passed explicitly: the implicit-dimension form of softmax is
    # deprecated and chooses the axis from the tensor rank. dim=1 is what the
    # implicit form selects for 2-D input.
    p = F.softmax(z, dim=1)  # mixture probability
    p = self.p_drop(p)

    # do reconstruction
    recon = F.softmax(self.decoder_bn(self.decoder(p)), dim=1)  # reconstructed distribution
    if compute_loss:
        loss = self.loss(
            input, recon, posterior_mean, posterior_logvar, posterior_var, avg_loss
        )
        return recon, loss
    return recon
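# 评论列表 (comment list -- web-page navigation text, not part of the code)
# 文章目录 (article table of contents -- web-page navigation text, not part of the code)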