def forward(self, x):
    """Run the network on ``x`` and return the classifier output.

    ``x`` may be passed bare or wrapped in a one-element list. In the
    wrapped case the activation is re-wrapped in a list before the three
    stages, and each stage is expected to hand back an indexable sequence:
    its first element is used as the running activation and everything
    after it is returned to the caller as ``features``.

    Args:
        x: The network input, or a list holding exactly one input.

    Returns:
        ``cls`` (the output of ``self.classifier``) when ``x`` was passed
        bare, or the tuple ``(cls, features)`` when ``x`` was passed as a
        list.

    Raises:
        AssertionError: If ``x`` is a list whose length is not one.
    """
    is_list = isinstance(x, list)
    if is_list:
        if len(x) != 1:
            # Raised explicitly instead of via ``assert`` so the check still
            # runs under ``python -O``; the exception type and message are
            # kept for callers that already rely on them.
            raise AssertionError(
                'The length of inputs must be one vs {}'.format(len(x)))
        x = x[0]

    # Stem: convolution, batch norm, ReLU applied in place.
    x = self.conv_1_3x3(x)
    x = F.relu(self.bn_1(x), inplace=True)

    # In list mode the stages receive a list so they can append to it.
    if is_list:
        x = [x]
    x = self.stage_1(x)
    x = self.stage_2(x)
    x = self.stage_3(x)

    features = None
    if is_list:
        # First element is the activation; the remainder are the extras
        # the stages left behind.
        x, features = x[0], x[1:]

    x = self.avgpool(x)
    # Collapse everything after the batch dimension before classifying.
    x = x.view(x.size(0), -1)
    cls = self.classifier(x)

    if is_list:
        return cls, features
    return cls
评论列表
文章目录