pytorch_utils.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:mcnPyTorch 作者: albanie 项目源码 文件源码
def forward(self, x):
        # mini = list(self.features.children())[:4]
        # mini_f = torch.nn.modules.Sequential(*mini) ;
        # y = mini_f(x)
        # ipdb.set_trace()
        # mini = list(self.features.children())

        x = self.features(x)
        if self.flatten_loc == 'classifier':
            x = x.view(x.size(0), -1)
            x = self.classifier(x)
        elif self.flatten_loc == 'end':
            x = self.classifier(x)
            x = x.view(x.size(0), -1)
        else:
            msg = 'unrecognised flatten_loc: {}'.format(self.flatten_loc)
            raise ValueError(msg)
        return x
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号