nn.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:WGAN_mnist 作者: rajeswar18 项目源码 文件源码
def resnet_block(input_, filter_size, num_filters,
                 activation=relu, downsample=False,
                 no_output_act=True,
                 use_shortcut=False,
                 use_wn=False,
                 W_init=Normal(0.02),
                 **kwargs):
    """Stack a two-convolution residual block on top of ``input_``.

    Residual branch: a ``conv_layer`` (stride 2 when ``downsample`` else 1,
    ``'same'`` padding, ``nonlinearity=activation``) wrapped by the chosen
    normalization helper, then a second ``conv_layer`` (stride 1, ``'same'``,
    no nonlinearity) wrapped the same way. The branch output is summed with
    either ``input_`` itself or a 1x1 ``'valid'`` projection of it.

    Args:
        input_: incoming layer the block is built on.
        filter_size: filter size passed to both residual-branch convolutions.
        num_filters: number of filters for every convolution in the block.
        activation: nonlinearity given to the first convolution and, when
            ``no_output_act`` is False, applied after the sum.
        downsample: if True, the first convolution and the projection use
            stride 2, and the projection shortcut is always used.
        no_output_act: if True (default), return the sum with no final
            nonlinearity.
        use_shortcut: use the 1x1 projection shortcut even without
            downsampling.
        use_wn: wrap convolutions with ``weight_norm`` instead of
            ``batch_norm``.
        W_init: weight initializer for the two residual-branch convolutions.
            The projection convolution is built without ``W`` and so uses
            ``conv_layer``'s own default.
        **kwargs: accepted for call compatibility; not used.

    Returns:
        The final layer of the block: the sum layer, or a NonlinearityLayer
        on top of it.
    """
    norm = weight_norm if use_wn else batch_norm
    stride = 2 if downsample else 1

    # First conv of the residual branch. The activation is handed to
    # conv_layer and the result wrapped by the normalization helper; the
    # intended ordering is conv -> norm -> activation (presumably the helper
    # moves the nonlinearity after normalizing -- confirm for weight_norm).
    branch = norm(conv_layer(input_, filter_size, num_filters,
                             stride, 'same', nonlinearity=activation,
                             W=W_init))
    # Second conv: linear, so the branch is added to the skip before any
    # output activation.
    branch = norm(conv_layer(branch, filter_size, num_filters, 1, 'same',
                             nonlinearity=None, W=W_init))

    # A projection is needed when the spatial size changes (downsample) or
    # when explicitly requested. It is built after the residual branch; keep
    # this order so layers (and their weights) are created in the same
    # sequence.
    needs_projection = downsample or use_shortcut
    if needs_projection:
        skip = conv_layer(input_, 1, num_filters, stride, 'valid',
                          nonlinearity=None)
    else:
        skip = input_
    summed = ElemwiseSumLayer([skip, branch])

    if no_output_act:
        return summed
    return NonlinearityLayer(summed, nonlinearity=activation)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号