def forward(self, input):
    """Normalize ``input`` over its last dimension, then apply an affine map.

    Computes ``(input - mean) / std`` along ``dim=-1`` and returns
    ``normalized * self.weight + self.bias`` (weight and bias are expanded to
    the output's shape).

    Notes:
        * ``std`` is the unbiased (Bessel-corrected) standard deviation, which
          is ``torch.std``'s default, whenever the last dimension has more
          than one element. Results therefore differ slightly from
          ``torch.nn.LayerNorm``, which uses the biased variance.
        * ``self.eps`` acts as a lower bound on ``std`` (via ``clamp``); it is
          not added to the variance.
        * NOTE(review): ``self.weight`` / ``self.bias`` presumably have shape
          ``(input.size(-1),)`` — confirm against ``__init__``.

    Args:
        input: Tensor to normalize; the reduction is over the last dimension.

    Returns:
        Tensor with the same shape as ``input``.
    """
    mu = torch.mean(input, dim=-1, keepdim=True)
    # With a single element along the last dim, the unbiased estimator divides
    # by n - 1 == 0 and yields NaN, which ``clamp`` does not remove. Fall back
    # to the biased estimator there (std == 0, clamped to eps), so the output
    # is ``bias`` instead of NaN. For n > 1 this is the original unbiased std.
    unbiased = input.size(-1) > 1
    sigma = torch.std(input, dim=-1, unbiased=unbiased, keepdim=True)
    sigma = sigma.clamp(min=self.eps)
    output = (input - mu) / sigma
    return output * self.weight.expand_as(output) + self.bias.expand_as(output)
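# 评论列表 (comment list) — web-page residue from the original source, not code
# 文章目录 (table of contents) — web-page residue from the original source, not code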