layers_export.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:Fabrik 作者: Cloud-CV 项目源码 文件源码
def pooling(layer, layer_in, layerId):
    """Build a pooling layer, optionally preceded by a zero-padding layer.

    Args:
        layer: Layer description. ``layer['params']`` supplies ``layer_type``
            ('1D', '2D'; any other value is handled as 3D), ``pool``
            ('MAX' or 'AVE'), the per-axis ``stride_*`` and ``kernel_*``
            values and, when padding is custom, the ``pad_*`` values.
        layer_in: Sequence of inputs; it is unpacked into the layer call.
        layerId: Key under which the pooling output is stored.

    Returns:
        dict: ``{layerId: <pooling output>}``. When ``get_padding(layer)``
        returns ``'custom'`` the dict also holds ``layerId + 'Pad'`` with the
        output of the explicit zero-padding layer.

    Raises:
        KeyError: If a required parameter is missing or the
            ``(layer_type, pool)`` pair has no entry in the lookup table.
    """
    pool_classes = {
        ('1D', 'MAX'): MaxPooling1D,
        ('2D', 'MAX'): MaxPooling2D,
        ('3D', 'MAX'): MaxPooling3D,
        ('1D', 'AVE'): AveragePooling1D,
        ('2D', 'AVE'): AveragePooling2D,
        ('3D', 'AVE'): AveragePooling3D,
    }
    result = {}
    params = layer['params']
    layer_type = params['layer_type']
    pool_type = params['pool']
    padding = get_padding(layer)

    # Axis suffixes, in the order the Keras layers receive them.
    if layer_type == '1D':
        axes = ('w',)
    elif layer_type == '2D':
        axes = ('h', 'w')
    else:
        axes = ('h', 'w', 'd')

    def per_axis(prefix):
        # 1D layers get a bare scalar; 2D/3D layers get a tuple.
        values = tuple(params[prefix + '_' + axis] for axis in axes)
        return values[0] if len(values) == 1 else values

    strides = per_axis('stride')
    kernel = per_axis('kernel')

    if padding == 'custom':
        pad_amount = per_axis('pad')
        if len(axes) == 1:
            pad_class = ZeroPadding1D
        elif len(axes) == 2:
            pad_class = ZeroPadding2D
        else:
            pad_class = ZeroPadding3D
        pad_key = layerId + 'Pad'
        result[pad_key] = pad_class(padding=pad_amount)(*layer_in)
        # The padding is now applied explicitly, so the pooling layer itself
        # runs unpadded on the padded output.
        padding = 'valid'
        layer_in = [result[pad_key]]

    pool_class = pool_classes[(layer_type, pool_type)]
    pool_layer = pool_class(pool_size=kernel, strides=strides, padding=padding)
    result[layerId] = pool_layer(*layer_in)
    return result


# ********** Locally-connected Layers **********
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号