def __init__(self, n_encdec=4, n_classes=12, in_channel=3, n_mid=64):
    """Build a SegNet made of ``n_encdec`` nested encoder-decoder stages.

    Stages are registered as child links ``encdec1`` ... ``encdec{n_encdec}``.
    Each stage except the last is given the following stage as its
    ``inside`` attribute. A 1x1 convolution ``conv_cls`` maps ``n_mid``
    channels to ``n_classes`` channels.

    Args:
        n_encdec (int): Number of ``EncDec`` stages to create; must be >= 1.
        n_classes (int): Output channels of the final 1x1 convolution
            ``conv_cls``.
        in_channel (int): Value passed as the first ``EncDec`` argument for
            the first stage (``n_mid`` is passed for every later stage).
        n_mid (int): Value passed as the second ``EncDec`` argument for
            every stage, and the input channel count of ``conv_cls``.

    Raises:
        ValueError: If ``n_encdec`` is less than 1.
    """
    # Validate explicitly rather than with ``assert``: asserts are stripped
    # under ``python -O``, which would let a stage-less network be built.
    if n_encdec < 1:
        raise ValueError('n_encdec must be >= 1, got {}'.format(n_encdec))

    w = math.sqrt(2)
    # NOTE(review): ``w`` is the 6th positional argument of Convolution2D,
    # which is ``wscale`` in the Chainer v1 API this code targets
    # (``add_link`` / ``self.train``). In Chainer >= 2 the 6th positional
    # parameter is ``nobias``, so pass it by keyword if porting.
    super(SegNet, self).__init__(
        conv_cls=L.Convolution2D(n_mid, n_classes, 1, 1, 0, w))

    # Create and register the stages encdec1 .. encdec{n_encdec}. Only the
    # first stage receives ``in_channel``; later stages receive ``n_mid``
    # (presumably the (in, out) channel counts -- confirm against EncDec).
    for i in six.moves.range(1, n_encdec + 1):
        self.add_link('encdec{}'.format(i),
                      EncDec(n_mid if i > 1 else in_channel, n_mid))

    # Nest the stages: stage d gets stage d + 1 as its ``inside`` attribute.
    # The last stage is left with whatever ``inside`` EncDec itself sets.
    # The links were already registered on ``self`` above, so no re-assignment
    # back onto ``self`` is needed here.
    for d in six.moves.range(1, n_encdec):
        outer = getattr(self, 'encdec{}'.format(d))
        outer.inside = getattr(self, 'encdec{}'.format(d + 1))

    self.n_encdec = n_encdec
    self.n_classes = n_classes
    # Chainer v1-style mode flag; presumably read elsewhere (e.g. __call__)
    # to switch between training and inference behaviour -- confirm.
    self.train = True
评论列表
文章目录