test_segnet.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:chainer-segnet 作者: pfnet-research 项目源码 文件源码
def test_remove_link(self):
        """Check that pruning leaves only the target depth's parameters.

        For every ``depth`` in ``1..self.n_encdec`` a fresh ``segnet.SegNet``
        is built, wrapped in ``segnet.SegNetLoss`` with
        ``train_depth=depth`` and registered with a ``MomentumSGD``
        optimizer. Every ``encdec{d}`` link with ``d != depth`` is then
        removed from the predictor. ``conv_cls`` is also removed when
        ``depth > 1``. Afterwards each remaining parameter name must
        contain ``encdec{depth}``, or ``conv_cls`` when ``depth == 1``.
        """
        opt = optimizers.MomentumSGD(lr=0.01)
        # Check each depth against a freshly constructed model.
        for depth in six.moves.range(1, self.n_encdec + 1):
            model = segnet.SegNet(self.n_encdec, self.n_classes,
                                  self.x_shape[1], self.n_mid)
            model = segnet.SegNetLoss(
                model, class_weight=None, train_depth=depth)
            opt.setup(model)

            target = 'encdec{}'.format(depth)

            # Deregister non-target links from the model set up in opt.
            # The classifier layer is kept only when training depth 1.
            if depth > 1:
                model.predictor.remove_link('conv_cls')
            for d in six.moves.range(1, self.n_encdec + 1):
                if d != depth:
                    model.predictor.remove_link('encdec{}'.format(d))

            # namedparams() yields (name, param) pairs; only the name is
            # inspected here.
            for name, _ in model.namedparams():
                if depth > 1:
                    self.assertIn(target, name)
                else:
                    self.assertTrue(
                        target in name or 'conv_cls' in name,
                        msg='unexpected parameter: {}'.format(name))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号