def forward(self, x, samples):
    """Produce proposal parameters from a batch of features.

    Two independent two-layer heads read the same input ``x``:

    * the mean head applies ``mean_drop -> mean_lin1 -> ReLU -> mean_drop
      -> mean_lin2``;
    * the variance head applies the same pattern with the ``vars_*``
      modules, then passes the result through ``softplus`` and multiplies
      it by ``self.softplus_boost``.

    Args:
        x: Input tensor fed to both heads. NOTE(review): the dim=1
            concatenation suggests a (batch, features) layout -- confirm.
        samples: Accepted for interface compatibility; not read here.

    Returns:
        A tuple ``(True, params)``, where ``params`` is the mean-head output
        followed by the variance-head output, concatenated along dim 1.
        NOTE(review): the meaning of the leading ``True`` is defined by the
        caller -- confirm.
    """
    def _head(drop, first, second, inp):
        # The same dropout module is applied before each linear layer.
        hidden = F.relu(first(drop(inp)))
        return second(drop(hidden))

    # Evaluate the mean head first. Dropout draws random masks in training
    # mode, so the call order determines how the RNG stream is consumed.
    mean = _head(self.mean_drop, self.mean_lin1, self.mean_lin2, x)
    raw_variances = _head(self.vars_drop, self.vars_lin1, self.vars_lin2, x)

    # softplus maps to non-negative values; softplus_boost rescales them.
    variances = F.softplus(raw_variances) * self.softplus_boost

    params = torch.cat([mean, variances], dim=1)
    # TODO: Transform mean and variances in the same fashion as in ProposalNormal
    return True, params
评论列表
文章目录