def residual_block(resnet_in, num_styles=None, num_filters=None, filter_size=3, stride=1):
    """Stack two style conv blocks and add the input back as a skip connection.

    Args:
        resnet_in: Input layer. It is summed element-wise with the output of
            the second conv block, and its ``output_shape[1]`` is read when
            ``num_filters`` is not given.
        num_styles: Forwarded unchanged to both ``style_conv_block`` calls.
        num_filters: Filter count passed to both conv blocks. When ``None``,
            falls back to ``resnet_in.output_shape[1]`` (presumably the
            channel axis of the input -- confirm against the layout in use).
        filter_size: Forwarded to both ``style_conv_block`` calls.
        stride: Forwarded to both ``style_conv_block`` calls.

    Returns:
        An ``ElemwiseSumLayer`` built over ``[conv2, resnet_in]``.
    """
    # Identity test rather than ``== None``: equality can be overridden by the
    # argument's type, while ``is None`` only matches the real default.
    if num_filters is None:
        num_filters = resnet_in.output_shape[1]

    conv1 = style_conv_block(resnet_in, num_styles, num_filters, filter_size, stride)
    # The second block additionally receives ``linear`` as its sixth positional
    # argument (presumably the nonlinearity, so that no activation is applied
    # before the sum -- verify against style_conv_block's signature).
    conv2 = style_conv_block(conv1, num_styles, num_filters, filter_size, stride, linear)

    # NOTE(review): the sum needs conv2 and resnet_in to have matching shapes;
    # a stride other than 1 or a num_filters different from the input's would
    # presumably break that -- confirm how callers use these arguments.
    return ElemwiseSumLayer([conv2, resnet_in])
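# Web-page residue from the source article (not code): "Comment list"
# Web-page residue from the source article (not code): "Table of contents"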