def __init__(self, depth=18, alpha=16, start_channel=16, skip=False):
    """Build a PyramidNet-style chain of widening residual blocks.

    Args:
        depth (int): Number of PyramidBlocks per stage (three stages total,
            with a stride-2 downsampling block between stages).
        alpha (int): Total channel widening per stage; each block adds
            ``alpha / depth`` channels (accumulated as a float, rounded
            only when instantiating a block).
        start_channel (int): Channel count produced by the stem conv.
        skip (bool): If True, assign each non-downsampling block a
            stochastic skip ratio that ramps linearly from 0 up to 0.5
            across the whole network; otherwise all ratios are 0.
    """
    super(PyramidNet, self).__init__()
    channel_diff = float(alpha) / depth
    # `channel` is kept as a float so rounding error does not accumulate;
    # each block receives int(round(...)) of the running value.
    channel = start_channel
    links = [('bconv1', BatchConv2D(3, channel, 3, 1, 1))]
    # Number of skip-ratio steps: all non-downsampling blocks minus one,
    # so the last block reaches exactly 0.5.
    skip_size = depth * 3 - 3

    def _grow(channel, **kwargs):
        # Append one PyramidBlock that widens `channel` by `channel_diff`,
        # naming it by its position in `links`; returns the new width.
        in_channel = channel
        channel += channel_diff
        links.append(('py{}'.format(len(links)),
                      PyramidBlock(int(round(in_channel)),
                                   int(round(channel)), **kwargs)))
        return channel

    # Stage 1: `depth` blocks at full resolution.
    for i in six.moves.range(depth):
        skip_ratio = float(i) / skip_size * 0.5 if skip else 0
        channel = _grow(channel, skip_ratio=skip_ratio)
    # Downsample, then stage 2 with `depth - 1` blocks.
    channel = _grow(channel, stride=2)
    for i in six.moves.range(depth - 1):
        skip_ratio = float(i + depth) / skip_size * 0.5 if skip else 0
        channel = _grow(channel, skip_ratio=skip_ratio)
    # Downsample, then stage 3 with `depth - 1` blocks.
    channel = _grow(channel, stride=2)
    for i in six.moves.range(depth - 1):
        skip_ratio = float(i + depth * 2 - 1) / skip_size * 0.5 if skip else 0
        channel = _grow(channel, skip_ratio=skip_ratio)

    # Classification head: BN -> ReLU -> 8x8 average pool -> 10-way linear.
    links.append(('bn{}'.format(len(links)),
                  L.BatchNormalization(int(round(channel)))))
    links.append(('_relu{}'.format(len(links)), F.ReLU()))
    links.append(('_apool{}'.format(len(links)),
                  F.AveragePooling2D(8, 1, 0, False)))
    links.append(('fc{}'.format(len(links)),
                  L.Linear(int(round(channel)), 10)))
    # Only parameterized layers are registered; names starting with '_'
    # (ReLU, pooling) are parameter-free and kept solely in `self.layers`
    # to drive the forward pass in order.
    for name, link in links:
        if not name.startswith('_'):
            self.add_link(name, link)
    self.layers = links