def forward(self, weight, bias, input):
# Assert we're using cudnn
for i in ([weight, bias, input]):
if i is not None and not(cudnn.is_acceptable(i)):
raise Exception('You must be using CUDNN to use _EfficientBatchNorm')
# Create save variables
self.save_mean = self.running_mean.new()
self.save_mean.resize_as_(self.running_mean)
self.save_var = self.running_var.new()
self.save_var.resize_as_(self.running_var)
# Do forward pass - store in input variable
res = type(input)(self.storage)
res.resize_as_(input)
torch._C._cudnn_batch_norm_forward(
input, res, weight, bias, self.running_mean, self.running_var,
self.save_mean, self.save_var, self.training, self.momentum, self.eps
)
return res
densenet_efficient.py 文件源码
python
阅读 18
收藏 0
点赞 0
评论 0
评论列表
文章目录