def test_remove_link(self):
    """Check that after removing links only the target depth's params remain.

    For each ``train_depth`` in ``1..self.n_encdec`` a fresh ``SegNet`` is
    wrapped in ``SegNetLoss`` and registered with a shared ``MomentumSGD``
    optimizer via ``setup``. The ``encdec{d}`` links for every ``d`` other
    than the current depth are then removed from the predictor, and
    ``conv_cls`` is removed as well when ``depth > 1``. Every remaining
    parameter name must then contain ``encdec{depth}``; at depth 1 it may
    contain ``conv_cls`` instead.
    """
    opt = optimizers.MomentumSGD(lr=0.01)
    for depth in six.moves.range(1, self.n_encdec + 1):
        model = segnet.SegNet(self.n_encdec, self.n_classes,
                              self.x_shape[1], self.n_mid)
        model = segnet.SegNetLoss(
            model, class_weight=None, train_depth=depth)
        opt.setup(model)

        # Remove the non-target links from the model that was just
        # registered with the optimizer.
        if depth > 1:
            model.predictor.remove_link('conv_cls')
        for d in six.moves.range(1, self.n_encdec + 1):
            if d != depth:
                model.predictor.remove_link('encdec{}'.format(d))

        # Substrings that a surviving parameter name is allowed to contain;
        # conv_cls is only kept (and thus allowed) at depth 1.
        allowed = ['encdec{}'.format(depth)]
        if depth == 1:
            allowed.append('conv_cls')

        names = [name for name, _ in model.namedparams()]
        # Guard against a vacuous pass if every parameter was removed.
        self.assertTrue(
            names, 'no parameters remain at depth {}'.format(depth))
        for name in names:
            self.assertTrue(
                any(key in name for key in allowed),
                'unexpected parameter {} at depth {} (allowed: {})'.format(
                    name, depth, allowed))
评论列表
文章目录