from lasagne.init import Normal
from lasagne.layers import ElemwiseSumLayer, NonlinearityLayer, batch_norm
from lasagne.nonlinearities import rectify as relu

# `conv_layer` and `weight_norm` are project-local helpers assumed to be
# defined elsewhere; a sketch of `conv_layer` follows the function.


def resnet_block(input_, filter_size, num_filters,
                 activation=relu, downsample=False,
                 no_output_act=True,
                 use_shortcut=False,
                 use_wn=False,
                 W_init=Normal(0.02),
                 **kwargs):
    """
    ResNet block: two convolutions plus a residual connection.

    When `downsample` is True, the first convolution uses stride 2 and the
    shortcut is projected with a strided 1x1 convolution so shapes match.
    `use_wn` selects weight normalization instead of batch normalization;
    `no_output_act` (default True) skips the activation after the sum.
    """
    normalization = weight_norm if use_wn else batch_norm
    block = []
    _stride = 2 if downsample else 1
    # conv -> norm -> activation (Lasagne's batch_norm removes the wrapped
    # layer's nonlinearity and reapplies it after normalization)
    block.append(normalization(conv_layer(input_, filter_size, num_filters,
                                          _stride, 'same',
                                          nonlinearity=activation,
                                          W=W_init)))
    # conv -> norm, no activation before the residual sum
    block.append(normalization(conv_layer(block[-1], filter_size, num_filters,
                                          1, 'same', nonlinearity=None,
                                          W=W_init)))
    if downsample or use_shortcut:
        # project the input with a 1x1 convolution so it matches the residual
        # branch in spatial size and channel count
        shortcut = conv_layer(input_, 1, num_filters, _stride, 'valid',
                              nonlinearity=None)
        block.append(ElemwiseSumLayer([shortcut, block[-1]]))
    else:
        # identity shortcut
        block.append(ElemwiseSumLayer([input_, block[-1]]))
    if not no_output_act:
        block.append(NonlinearityLayer(block[-1], nonlinearity=activation))
    return block[-1]
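
# `resnet_block` relies on a `conv_layer` wrapper that is not shown in this
# snippet. The sketch below is one plausible shape for it, assuming it wraps
# lasagne.layers.Conv2DLayer and takes (incoming, filter_size, num_filters,
# stride, pad) positionally; the original project's helper may differ.

from lasagne.layers import Conv2DLayer, InputLayer


def conv_layer(incoming, filter_size, num_filters, stride, pad, **kwargs):
    # Hypothetical wrapper matching the call sites in resnet_block above;
    # extra keyword arguments (W, nonlinearity, ...) pass straight through.
    return Conv2DLayer(incoming, num_filters=num_filters,
                       filter_size=filter_size, stride=stride,
                       pad=pad, **kwargs)


# Usage sketch: a plain block followed by a downsampling block (batch
# normalization is used, since use_wn defaults to False).
net = InputLayer((None, 64, 32, 32))
net = resnet_block(net, 3, 64)                    # identity shortcut
net = resnet_block(net, 3, 128, downsample=True)  # strided 1x1 projection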