def forward(self, X):
    """Sample a weight matrix and bias, then apply the affine map ``X @ W + b``.

    The parent class's ``forward`` is applied to ``X`` first (its effect is not
    visible here). Weights and bias are then drawn with the reparameterization
    ``param = mu + softplus(rho) * eps``. The sampled values are stored on
    ``self.W`` and ``self.b``, presumably so a KL / log-prior term can be
    computed after the pass -- confirm against callers.

    Args:
        X: Input tensor. Assumed to have ``input_dim`` as its last dimension
            so that ``X @ W`` is valid -- TODO confirm.

    Returns:
        ``X @ W + b`` for the freshly sampled ``W`` and ``b``, with ``b``
        expanded to exactly the shape of ``X @ W``.
    """
    # Explicit two-argument super() is kept on purpose: it does not depend on
    # the implicit __class__ cell.
    X = super(ProbabilisticDense, self).forward(X)

    # NOTE(review): eps is drawn from N(0, sigma_prior**2), not the N(0, 1)
    # used in standard Bayes-by-Backprop. This is kept as-is to preserve
    # behaviour -- confirm the e**-3 noise scaling is intentional.
    sigma_prior = math.exp(-3)

    # softplus(rho) == log1p(exp(rho)) mathematically, but the library version
    # does not overflow to inf for large rho.
    softplus = torch.nn.functional.softplus

    # Noise is drawn on the parameters' device/dtype. A bare torch.zeros(...)
    # would always be CPU float32, which breaks on GPU or in other precisions.
    W_eps = sigma_prior * torch.randn(
        self.input_dim,
        self.output_dim,
        device=self.W_mu.device,
        dtype=self.W_mu.dtype,
    )
    self.W = W = self.W_mu + softplus(self.W_rho) * W_eps

    b_eps = sigma_prior * torch.randn(
        self.output_dim,
        device=self.b_mu.device,
        dtype=self.b_mu.dtype,
    )
    self.b = b = self.b_mu + softplus(self.b_rho) * b_eps

    XW = X @ W
    return XW + b.expand_as(XW)
评论列表
文章目录