def __init__(self, in_channels, out_channels, ksize=3, fiber_map='id',
             stride=1, pad=1, wscale=1, bias=0, nobias=False, use_cudnn=True,
             initialW=None, initial_bias=None):
    """Build a residual block and register its links.

    Registers ``bn1`` (BatchNormalization over ``in_channels``), ``conv1``
    (``in_channels -> out_channels``, stride ``stride``), ``bn2``
    (BatchNormalization over ``out_channels``) and ``conv2``
    (``out_channels -> out_channels``, stride 1). It also sets the shortcut
    ``fiber_map``. How these are wired together happens outside this method.
    Presumably the conv branch is added to ``fiber_map(x)``; TODO confirm
    against ``__call__``.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        ksize (int): Kernel size of ``conv1`` and ``conv2``; must be odd.
        fiber_map (str): Shortcut type.
            ``'id'``: ``F.identity``. Requires ``in_channels == out_channels``
            and ``stride == 1`` so the shortcut keeps the branch's shape.
            ``'linear'``: a 1x1, pad-0 ``L.Convolution2D`` using the same
            ``stride`` as ``conv1``.
        stride (int): Stride of ``conv1`` and of the ``'linear'`` shortcut.
        pad (int): Padding of ``conv1``/``conv2``; must equal
            ``(ksize - 1) // 2``.
        wscale (float): Passed as the ``wscale`` argument of every
            ``L.Convolution2D`` created here.
        bias, nobias, use_cudnn, initialW, initial_bias: Accepted (they
            mirror ``L.Convolution2D``'s signature) but NOT used by this
            method.

    Raises:
        ValueError: If ``ksize`` is even, ``pad != (ksize - 1) // 2``, the
            ``'id'`` shortcut is requested with ``in_channels !=
            out_channels`` or ``stride != 1``, or ``fiber_map`` is neither
            ``'id'`` nor ``'linear'``.
    """
    # Odd ksize with pad == (ksize - 1) // 2 makes each conv map H rows to
    # floor((H - 1) / stride) + 1, so a stride-1 conv preserves the size.
    # Explicit raises instead of asserts: asserts vanish under `python -O`.
    if ksize % 2 != 1:
        raise ValueError('ksize must be odd, got {}'.format(ksize))
    if pad != (ksize - 1) // 2:
        raise ValueError('pad must equal (ksize - 1) // 2 = {}, got {}'.format(
            (ksize - 1) // 2, pad))

    # Validate the shortcut before allocating any parameters.
    if fiber_map == 'id':
        if in_channels != out_channels:
            raise ValueError(
                "fiber_map='id' requires in_channels == out_channels, "
                'got {} and {}'.format(in_channels, out_channels))
        if stride != 1:
            # F.identity cannot downsample, so the shapes would not match.
            raise ValueError(
                "fiber_map='id' requires stride == 1, got {}".format(stride))
    elif fiber_map != 'linear':
        raise ValueError('Unimplemented fiber map {}'.format(fiber_map))

    super(ResBlock, self).__init__(
        bn1=L.BatchNormalization(in_channels),
        conv1=L.Convolution2D(in_channels, out_channels, ksize, stride, pad, wscale),
        bn2=L.BatchNormalization(out_channels),
        conv2=L.Convolution2D(out_channels, out_channels, ksize, 1, pad, wscale),
    )

    if fiber_map == 'id':
        self.fiber_map = F.identity
    else:
        # 1x1, pad-0 projection. It must use the same stride as conv1
        # (previously hard-coded to 2) so its output size matches the branch.
        self.add_link('fiber_map',
                      L.Convolution2D(in_channels, out_channels, 1, stride, 0, wscale))
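# Comment list  (scraped web-page label "评论列表"; not part of the code)
# Table of contents  (scraped web-page label "文章目录"; not part of the code)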