def pool(inputs, window_size, pooling_type, padding, strides=(1, 1)):
"""Pooling (supports "LEFT")."""
with tf.name_scope("pool", [inputs]):
static_shape = inputs.get_shape()
if not static_shape or len(static_shape) != 4:
raise ValueError("Inputs to conv must have statically known rank 4.")
# Add support for left padding.
if padding == "LEFT":
assert window_size[0] % 2 == 1 and window_size[1] % 2 == 1
if len(static_shape) == 3:
width_padding = 2 * (window_size[1] // 2)
padding_ = [[0, 0], [width_padding, 0], [0, 0]]
else:
height_padding = 2 * (window_size[0] // 2)
cond_padding = tf.cond(
tf.equal(shape_list(inputs)[2], 1), lambda: tf.constant(0),
lambda: tf.constant(2 * (window_size[1] // 2)))
width_padding = 0 if static_shape[2] == 1 else cond_padding
padding_ = [[0, 0], [height_padding, 0], [width_padding, 0], [0, 0]]
inputs = tf.pad(inputs, padding_)
inputs.set_shape([static_shape[0], None, None, static_shape[3]])
padding = "VALID"
return tf.nn.pool(inputs, window_size, pooling_type, padding, strides=strides)
评论列表
文章目录