def forward(self, x):
    """Dense-layer forward pass with pre-activation (BN -> ReLU -> conv x2).

    The input is concatenated with the layer's output along the channel
    dimension (DenseNet-style). When ``self.active`` is False the layer is
    treated as frozen: batchnorm runs in eval mode (running statistics are
    used and not updated) and the result is detached so no gradients
    propagate into this layer.

    Args:
        x: input tensor; assumed NCHW-like so dim 1 is channels — confirm
           against the surrounding model.

    Returns:
        Concatenation of ``x`` and the transformed features along dim 1;
        detached from the autograd graph when the layer is inactive.
    """
    if not self.active:
        # Frozen: eval mode makes BN read its running stats instead of
        # batch stats, and prevents those stats from being updated.
        self.eval()
    h = self.conv1(F.relu(self.bn1(x)))
    h = self.conv2(F.relu(self.bn2(h)))
    h = torch.cat((x, h), 1)
    # Inactive layers return a detached tensor to cut off backprop.
    return h if self.active else h.detach()
# NOTE(review): removed non-code residue from web scraping here
# (blog navigation labels: "comment list" / "article table of contents");
# bare non-comment text would be a syntax error in a Python source file.