def bn_gamma_beta(self, x):
    # Apply the learnable batch-norm shift (beta) and, optionally, scale (gamma).
    # A column of ones turns the (1, features) beta/gamma rows into full
    # (batch, features) matrices via a matrix multiply, one row per sample.
    if self.use_cuda:
        ones = Parameter(torch.ones(x.size()[0], 1).cuda())
    else:
        ones = Parameter(torch.ones(x.size()[0], 1))
    t = x + ones.mm(self.bn_beta)                  # shift: x + beta
    if self.train_bn_scaling:
        t = torch.mul(t, ones.mm(self.bn_gamma))   # scale: (x + beta) * gamma
    return t
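For context, here is a minimal, self-contained sketch of the same shift-and-scale step outside the class, assuming `bn_beta` and `bn_gamma` are stored as `(1, features)` parameters as in the method above; the names `batch` and `features` and the example shapes are illustrative, not part of the original code.

# Standalone sketch of the beta/gamma step (illustrative shapes and names).
import torch
from torch.nn import Parameter

batch, features = 4, 3
x = torch.randn(batch, features)
bn_beta = Parameter(torch.zeros(1, features))    # learnable shift, shape (1, features)
bn_gamma = Parameter(torch.ones(1, features))    # learnable scale, shape (1, features)

ones = torch.ones(batch, 1)                      # column vector used for broadcasting
t = x + ones.mm(bn_beta)                         # x + beta, shape (batch, features)
t = torch.mul(t, ones.mm(bn_gamma))              # (x + beta) * gamma

The `ones.mm(...)` trick simply repeats the single beta/gamma row for every sample in the batch; modern PyTorch broadcasting (`x + bn_beta`) would achieve the same result.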