def forward(self, x):
    # Pre-activation: BN + ReLU before the first convolution. When the input
    # and output shapes differ, keep the activated tensor in x so the 1x1
    # shortcut below also sees the normalized input.
    if not self.equalInOut:
        x = self.relu1(self.bn1(x))
    else:
        out = self.relu1(self.bn1(x))
    out = self.conv1(out if self.equalInOut else x)
    if self.droprate > 0:
        out = F.dropout(out, p=self.droprate, training=self.training)
    out = self.conv2(self.relu2(self.bn2(out)))
    # Residual connection: identity shortcut when shapes match, otherwise a
    # 1x1 projection of the (already activated) input.
    return torch.add(x if self.equalInOut else self.convShortcut(x), out)
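For reference, the attributes this forward pass relies on (bn1/relu1/conv1, bn2/relu2/conv2, droprate, equalInOut, convShortcut) are defined in the block's constructor, which is not shown here. Below is a minimal sketch of such a constructor, assuming the usual wide-residual basic block layout (two 3x3 convolutions with batch norm and an optional 1x1 projection shortcut); the class name, parameter names, and default values are illustrative, not taken from the original source.

import torch                      # used by torch.add in the forward() above
import torch.nn as nn
import torch.nn.functional as F   # used by F.dropout in the forward() above


class BasicBlock(nn.Module):
    """Residual block; pairing this with the forward() above completes the module."""

    def __init__(self, in_planes, out_planes, stride=1, droprate=0.0):
        super().__init__()
        # First BN-ReLU-Conv group (3x3, possibly strided).
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        # Second BN-ReLU-Conv group (3x3, stride 1).
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.droprate = droprate
        # The identity shortcut only works when the residual branch preserves
        # both the channel count and the spatial size; otherwise project the
        # input with a strided 1x1 convolution so the torch.add() lines up.
        self.equalInOut = (in_planes == out_planes and stride == 1)
        self.convShortcut = None if self.equalInOut else nn.Conv2d(
            in_planes, out_planes, kernel_size=1, stride=stride,
            padding=0, bias=False)

The equalInOut flag is what lets forward() skip the projection entirely when the identity shortcut already matches the output shape, which keeps the common case (same width, stride 1) free of extra parameters.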