def forward(self, input):
    """Apply the linear map using freshly sampled Gaussian weight/bias noise.

    Every call draws new standard-normal noise into the ``epsilon_weight``
    buffer (and into ``epsilon_bias`` when the layer has a bias), then
    computes::

        F.linear(input,
                 weight + sigma_weight * epsilon_weight,
                 bias + sigma_bias * epsilon_bias)

    The noise is elementwise (one sample per weight entry). This looks like
    the independent-Gaussian NoisyNet variant — NOTE(review): confirm with
    the enclosing class.

    Args:
        input: Tensor passed straight to ``F.linear``; presumably shaped
            ``(..., in_features)`` — TODO confirm against the class.

    Returns:
        The result of ``F.linear`` with the noise-perturbed weight and bias.
    """
    # Resample into the buffers so they keep holding the noise used by the
    # most recent call, but give autograd a private copy. The product with
    # sigma_* saves the noise for backward; if it saved the buffer itself,
    # the next call's in-place resample would invalidate that graph
    # ("modified by an inplace operation") whenever forward runs twice
    # before backward. Weight noise is drawn before bias noise, matching the
    # original RNG consumption order.
    eps_weight = self.epsilon_weight.normal_().clone()
    bias = self.bias
    if bias is not None:
        eps_bias = self.epsilon_bias.normal_().clone()
        bias = bias + self.sigma_bias * eps_bias
    weight = self.weight + self.sigma_weight * eps_weight
    return F.linear(input, weight, bias)
评论列表
文章目录