def forward(self, x):
    """Run the residual basic block, optionally threading a feature list.

    The block computes ``relu(residual + bn_b(conv_b(relu(bn_a(conv_a(x))))))``,
    where ``residual`` is ``x`` or ``self.downsample(x)`` when a downsample
    module is set.

    Args:
        x: Either a tensor, or a list ``[tensor, *features]``. The first
            element is the block input. The remaining entries are passed
            through unchanged (presumably features appended by earlier blocks
            of this kind -- confirm against the caller).

    Returns:
        If ``x`` is not a list: the block output tensor.
        If ``x`` is a list: ``[output, *features, bn_a, bn_b]``, i.e. the
        incoming extra entries followed by this block's two batch-norm
        outputs, both taken before any activation is applied.
    """
    if isinstance(x, list):
        x, is_list, features = x[0], True, x[1:]
    else:
        is_list, features = False, None
    residual = x
    conv_a = self.conv_a(x)
    bn_a = self.bn_a(conv_a)
    # An in-place ReLU would overwrite bn_a, so the "bn_a" feature exported
    # below would silently become the post-ReLU activation (unlike bn_b, which
    # is exported pre-activation). Only reuse the buffer when bn_a is not
    # returned to the caller.
    relu_a = F.relu(bn_a, inplace=not is_list)
    conv_b = self.conv_b(relu_a)
    bn_b = self.bn_b(conv_b)
    if self.downsample is not None:
        residual = self.downsample(x)
    # `residual + bn_b` allocates a new tensor, so this in-place ReLU leaves
    # bn_b, residual and x untouched.
    output = F.relu(residual + bn_b, inplace=True)
    if is_list:
        return [output] + features + [bn_a, bn_b]
    return output
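# Web-page residue captured with this snippet (not code):
# 评论列表 ("Comment list"), 文章目录 ("Table of contents")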