def _as_spatial_pair(value, name):
    """Return *value* as a 2-element list ``[height, width]``.

    A bare ``int`` is broadcast to both spatial dimensions; any other value is
    treated as a sequence and must contain exactly two entries.

    Raises:
        ValueError: if *value* is a sequence whose length is not 2.
    """
    if isinstance(value, int):
        return [value, value]
    pair = list(value)
    if len(pair) != 2:
        raise ValueError(
            "%s must be an int or a pair of ints, got %r" % (name, value))
    return pair


def format_input_left_padding(inputs, **kwargs):
    """Left-pad a rank-4 input so a following "VALID" convolution is causal.

    Pads ``2 * (k // 2) * d`` elements (``k`` = filter size, ``d`` = dilation)
    on the leading side of axes 1 and 2 only, then forces
    ``kwargs["padding"] = "VALID"``. For odd ``k`` this is ``(k - 1) * d``, so
    with stride 1 the convolution output keeps the input's size along those
    axes while each position only sees earlier positions.

    Args:
        inputs: tensor whose static shape must have rank 4. Axes 1 and 2 are
            the ones padded (presumably height/width of an NHWC tensor --
            confirm against callers).
        **kwargs: convolution keyword arguments. Must contain ``filter_size``
            (int or pair of odd ints). May contain ``dilation`` (int or pair;
            defaults to ``(1, 1)`` when absent or ``None``).

    Returns:
        ``(padded_inputs, kwargs)``, where ``kwargs`` is this call's own
        keyword dict with ``"padding"`` set to ``"VALID"``.

    Raises:
        ValueError: if the static rank is not 4, ``filter_size`` is missing
            or not odd, or a size/dilation is not an int or a pair.
    """
    static_shape = inputs.get_shape()
    if not static_shape or len(static_shape) != 4:
        raise ValueError(
            "Inputs to conv must have statically known rank 4. Shape: " + str(static_shape))

    filter_size = kwargs.get("filter_size")
    if filter_size is None:
        raise ValueError(
            "format_input_left_padding requires a non-None 'filter_size' kwarg.")
    filter_size = _as_spatial_pair(filter_size, "filter_size")

    # Previously the supplied dilation was read into an unused variable, so the
    # padding was always computed as if dilation were (1, 1).
    dilation = kwargs.get("dilation")
    dilation = [1, 1] if dilation is None else _as_spatial_pair(dilation, "dilation")

    if filter_size[0] % 2 != 1 or filter_size[1] % 2 != 1:
        raise ValueError("filter_size must be odd in both dimensions, got %r"
                         % (filter_size,))

    height_padding = 2 * (filter_size[0] // 2) * dilation[0]
    if static_shape[2] == 1:
        # Width statically 1: no width padding.
        # NOTE(review): presumably such inputs are used with a width-1 kernel.
        width_padding = 0
    else:
        # Width unknown or > 1 at graph-build time; the runtime width may
        # still be 1, so decide dynamically.
        width_padding = tf.cond(
            tf.equal(tf.shape(inputs)[2], 1), lambda: tf.constant(0),
            lambda: tf.constant(2 * (filter_size[1] // 2) * dilation[1]))

    # Pad only the leading side of axes 1 and 2 (causal padding).
    padding = [[0, 0], [height_padding, 0], [width_padding, 0], [0, 0]]
    inputs = tf.pad(inputs, padding)
    # Set middle two dimensions to None to prevent convolution from complaining
    inputs.set_shape([static_shape[0], None, None, static_shape[3]])
    kwargs["padding"] = "VALID"
    return inputs, kwargs
评论列表
文章目录