def global_pool(inp, kind='avg', keep_dims=False, name=None):
    """Apply global average or max pooling over the spatial axes of ``inp``.

    Axes 1 and 2 are taken as height and width. The window is
    ``ksize=[1, h, w, 1]`` and no ``data_format`` is passed, so this assumes
    TF's default NHWC layout. With ``VALID`` padding and a window equal to
    the full spatial size, every channel is reduced to a single value.

    Args:
        inp: Rank-4 tensor. Its height and width (axes 1 and 2) must be
            known statically, because they are used as the pooling window.
        kind: ``'avg'`` or ``'max'``; selects ``tf.nn.avg_pool`` or
            ``tf.nn.max_pool``.
        keep_dims: If True, keep the singleton spatial axes, giving
            ``[N, 1, 1, C]``. Otherwise drop them, giving ``[N, C]``.
        name: Name of the output op. Defaults to ``'global_<kind>_pool'``.

    Returns:
        The pooled tensor.

    Raises:
        ValueError: If ``kind`` is not ``'avg'`` or ``'max'``, or if the
            height/width of ``inp`` are not statically known.
    """
    if kind not in ('max', 'avg'):
        raise ValueError('Only global avg or max pool is allowed, but '
                         'you requested {}.'.format(kind))
    if name is None:
        name = 'global_{}_pool'.format(kind)

    h, w = inp.get_shape().as_list()[1:3]
    if h is None or w is None:
        # ksize must be concrete integers; a dynamic spatial size cannot be
        # used as the pooling window here.
        raise ValueError('global_pool requires statically known height and '
                         'width, got {}x{}.'.format(h, w))

    pool_fn = getattr(tf.nn, kind + '_pool')
    out = pool_fn(inp,
                  ksize=[1, h, w, 1],
                  strides=[1, 1, 1, 1],
                  padding='VALID')

    if keep_dims:
        return tf.identity(out, name=name)
    # A full-size VALID window leaves shape [N, 1, 1, C]. Squeezing the two
    # singleton axes yields [N, C] without reading the static batch size.
    # That size is None for inputs with an unknown batch dimension, where
    # reshaping to [None, -1] would fail.
    return tf.squeeze(out, axis=[1, 2], name=name)
评论列表
文章目录