def forward(self, x):
    """Apply ``self.features`` then ``self.classifier``, flattening at ``self.flatten_loc``.

    ``self.flatten_loc`` selects where the tensor is collapsed to
    ``(x.size(0), -1)``:

    * ``'classifier'`` -- flatten the feature output, then run the classifier.
    * ``'end'`` -- run the classifier on the unflattened feature output, then
      flatten the classifier's result.

    Args:
        x: Input passed to ``self.features``. Its first dimension is preserved
            by the flattening (presumably the batch dimension -- confirm).

    Returns:
        A 2-D tensor of shape ``(N, -1)``, where ``N`` is the first dimension of
        the tensor being flattened.

    Raises:
        ValueError: If ``self.flatten_loc`` is neither ``'classifier'`` nor
            ``'end'``. Checked before any layer runs, so no compute is wasted.
    """
    if self.flatten_loc not in ('classifier', 'end'):
        msg = 'unrecognised flatten_loc: {}'.format(self.flatten_loc)
        raise ValueError(msg)

    x = self.features(x)
    if self.flatten_loc == 'classifier':
        # reshape (not view) so non-contiguous feature maps (e.g. transposed or
        # channels_last outputs) are accepted instead of raising RuntimeError.
        x = x.reshape(x.size(0), -1)
        x = self.classifier(x)
    else:  # 'end': the classifier consumes the unflattened feature map.
        x = self.classifier(x)
        x = x.reshape(x.size(0), -1)
    return x
评论列表
文章目录