# This is a method of EfficientBatchNorm; it assumes the module-level imports
#   import torch
#   from torch.backends import cudnn
def forward(self, weight, bias, input):
    # Verify that cuDNN can handle every tensor we were given
    for tensor in (weight, bias, input):
        if tensor is not None and not cudnn.is_acceptable(tensor):
            raise RuntimeError('You must be using cuDNN to use EfficientBatchNorm')
    # Create buffers for the batch statistics that cuDNN saves for the backward pass
    self.save_mean = self.running_mean.new()
    self.save_mean.resize_as_(self.running_mean)
    self.save_var = self.running_var.new()
    self.save_var.resize_as_(self.running_var)
    # Do the forward pass - the output is allocated from the shared storage
    # (moved to the device that holds the weights), not from fresh memory
    cur_device_id = weight.get_device()
    res = type(input)(self.storage.change_device(cur_device_id)).resize_as_(input)
    assert weight.get_device() == res.get_device(), \
        "weight and output should be on the same device!"
    torch._C._cudnn_batch_norm_forward(input, res,
                                       weight, bias,
                                       self.running_mean, self.running_var,
                                       self.save_mean, self.save_var,
                                       self.training,
                                       self.momentum,
                                       self.eps)
    return res
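
The output-allocation line is the heart of the memory trick: instead of requesting a fresh buffer, it wraps a tensor around a pre-existing shared storage and reshapes it in place. Below is a minimal standalone sketch of that pattern using the old torch.FloatStorage/FloatTensor API (a CPU storage for simplicity; the variable names are illustrative, not the file's actual API):

import torch

# A pre-allocated scratch buffer shared by several layers (assumed setup)
shared_storage = torch.FloatStorage(1024)

template = torch.randn(4, 16)  # stands in for `input` above

# Wrapping a tensor around an existing storage allocates no new memory;
# resize_as_ reshapes the view (growing the storage only if it is too small)
res = torch.FloatTensor(shared_storage).resize_as_(template)
assert res.numel() == template.numel()

# In the method above, cuDNN then writes the normalized result
#   res = weight * (input - mean) / sqrt(var + eps) + bias
# directly into this shared buffer, which is why `res` must live on the
# same device as the weights.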
Source file: densenet_efficient_multi_gpu.py (Python)