model.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:tfutils 作者: neuroailab 项目源码 文件源码
def global_pool(inp, kind='avg', keep_dims=False, name=None):
    if kind not in ['max', 'avg']:
        raise ValueError('Only global avg or max pool is allowed, but'
                            'you requested {}.'.format(kind))
    if name is None:
        name = 'global_{}_pool'.format(kind)
    h, w = inp.get_shape().as_list()[1:3]
    out = getattr(tf.nn, kind + '_pool')(inp,
                                    ksize=[1,h,w,1],
                                    strides=[1,1,1,1],
                                    padding='VALID')
    if keep_dims:
        output = tf.identity(out, name=name)
    else:
        output = tf.reshape(out, [out.get_shape().as_list()[0], -1], name=name)

    return output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号