theano_backend.py source code

python

Project: odin_old  Author: trungnt13
# NOTE: these imports are not part of the excerpt. They are assumed from the
# names used below (Theano's pooling module and the old CUDA backend's cuDNN
# wrapper). `_on_gpu` is a helper defined elsewhere in theano_backend.py.
from theano.tensor.signal import pool
from theano.sandbox.cuda import dnn

def pool2d(x, pool_size, strides=(1, 1), border_mode='valid',
           dim_ordering='th', pool_mode='max'):
    # ====== dim ordering ====== #
    if dim_ordering not in {'th', 'tf'}:
        raise Exception('Unknown dim_ordering ' + str(dim_ordering))
    if dim_ordering == 'tf':
        x = x.dimshuffle((0, 3, 1, 2))
    # ====== border mode ====== #
    if border_mode == 'same':
        w_pad = pool_size[0] - 2 if pool_size[0] % 2 == 1 else pool_size[0] - 1
        h_pad = pool_size[1] - 2 if pool_size[1] % 2 == 1 else pool_size[1] - 1
        padding = (w_pad, h_pad)
    elif border_mode == 'valid':
        padding = (0, 0)
    elif isinstance(border_mode, (tuple, list)):
        padding = tuple(border_mode)
    else:
        raise Exception('Invalid border mode: ' + str(border_mode))

    # ====== pooling ====== #
    if _on_gpu() and dnn.dnn_available():
        pool_out = dnn.dnn_pool(x, pool_size,
                                stride=strides,
                                mode=pool_mode,
                                pad=padding)
    else:  # CPU version, supported by Theano
        pool_out = pool.pool_2d(x, ds=pool_size, st=strides,
                                ignore_border=True,
                                padding=padding,
                                mode=pool_mode)

    if dim_ordering == 'tf':
        pool_out = pool_out.dimshuffle((0, 2, 3, 1))
    return pool_out
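
The padding rule in the 'same' branch can be checked without Theano. The sketch below is a hedged illustration, not part of odin_old: `same_padding` and `pooled_length` are hypothetical helper names, and the output-length formula assumes the usual floor rule for pooling with `ignore_border=True`, namely `(length + 2 * pad - pool) // stride + 1`.

```python
def same_padding(pool_size):
    """Per-dimension padding chosen by the 'same' branch of pool2d above.

    Odd pool sizes get (size - 2), even pool sizes get (size - 1).
    """
    return tuple(p - 2 if p % 2 == 1 else p - 1 for p in pool_size)


def pooled_length(length, pool, stride, pad):
    """Output length along one axis.

    Assumes the floor formula (length + 2 * pad - pool) // stride + 1,
    i.e. partial windows at the border are dropped.
    """
    return (length + 2 * pad - pool) // stride + 1


if __name__ == "__main__":
    for size in (2, 3, 5):
        pad = same_padding((size, size))[0]
        out = pooled_length(8, size, 1, pad)
        print("pool=%d pad=%d input=8 output=%d" % (size, pad, out))
```

Under these assumptions, a 3x3 pool with stride 1 preserves the input length, while 2x2 and 5x5 pools produce an output longer than the input. A caller that needs an exact 'same' shape would therefore have to crop the result, which `pool2d` as shown does not do.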