def forward(self, z):
    """Normalize ``z`` over its last dimension, then scale and shift it.

    Each slice along the last dimension has its mean subtracted and is
    divided by ``std + self.eps``; the result is multiplied by ``self.a_2``
    and offset by ``self.b_2``.

    Args:
        z: Input tensor with at least two dimensions (``z.size(1)`` is
            queried).

    Returns:
        ``z`` itself, unchanged, when ``z.size(1) == 1``; otherwise a new
        tensor with the same shape as ``z``.
    """
    # NOTE(review): this guard tests dimension 1, while the statistics below
    # are taken over the last dimension. Kept as-is for compatibility --
    # confirm this is the intended condition.
    if z.size(1) == 1:
        return z

    # keepdim=True leaves a trailing axis of size 1, so the statistics
    # broadcast against z without an explicit expand.
    mu = torch.mean(z, dim=-1, keepdim=True)
    sigma = torch.std(z, dim=-1, keepdim=True)

    # eps is added to the standard deviation (not the variance) so that a
    # constant slice does not divide by zero.
    normalized = (z - mu) / (sigma + self.eps)

    # a_2 / b_2 are presumably per-feature vectors matching z's last
    # dimension (verify against the constructor); broadcasting applies them
    # to every slice.
    return normalized * self.a_2 + self.b_2
Modules.py 文件源码
python
阅读 33
收藏 0
点赞 0
评论 0
评论列表
文章目录