special_fn.py source code

python

Project: tefla    Author: openAGI
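```python
import tensorflow as tf


def format_input_left_padding(inputs, **kwargs):
    """Left-pad a 4-D (NHWC) input so that a following VALID convolution is causal.

    Prepends (k - 1) * dilation rows and, unless the width is 1, the same number
    of columns, then sets kwargs["padding"] = "VALID" for the downstream conv.
    """
    static_shape = inputs.get_shape()
    if not static_shape or len(static_shape) != 4:
        raise ValueError(
            "Inputs to conv must have statically known rank 4. Shape: " + str(static_shape))
    filter_size = kwargs.get("filter_size")
    assert filter_size is not None, "filter_size is required"
    if isinstance(filter_size, int):
        filter_size = [filter_size, filter_size]
    # Optional dilation rate; an int applies to both spatial axes.
    dilation = kwargs.get("dilation", (1, 1))
    if isinstance(dilation, int):
        dilation = (dilation, dilation)
    # Only odd kernel sizes are supported.
    assert filter_size[0] % 2 == 1 and filter_size[1] % 2 == 1
    # 2 * (k // 2) * d == (k - 1) * d rows of zeros, all on top.
    height_padding = 2 * (filter_size[0] // 2) * dilation[0]
    # A width of 1 needs no width padding: decide statically when the width is
    # known, otherwise at run time with tf.cond.
    cond_padding = tf.cond(
        tf.equal(tf.shape(inputs)[2], 1), lambda: tf.constant(0),
        lambda: tf.constant(2 * (filter_size[1] // 2) * dilation[1]))
    width_padding = 0 if static_shape[2] == 1 else cond_padding
    padding = [[0, 0], [height_padding, 0], [width_padding, 0], [0, 0]]
    inputs = tf.pad(inputs, padding)
    # Set middle two dimensions to None to prevent convolution from complaining
    inputs.set_shape([static_shape[0], None, None, static_shape[3]])
    kwargs["padding"] = "VALID"
    return inputs, kwargs
```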
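The padding arithmetic above can be checked without TensorFlow. The sketch below is a hypothetical 1-D analogue (the helpers `left_padding` and `causal_conv1d` are illustrative names, not part of tefla): it prepends `2 * (k // 2) * d` zeros, the same formula used for `height_padding`, then runs a VALID dilated cross-correlation, which is what `tf.nn.conv2d` computes. The output keeps the input length, and output position i depends only on inputs at positions <= i.

```python
from typing import List, Sequence


def left_padding(filter_size: int, dilation: int = 1) -> int:
    """Zeros to prepend: 2 * (k // 2) * d, which equals (k - 1) * d for odd k."""
    assert filter_size % 2 == 1, "odd kernel sizes only, as in the original"
    return 2 * (filter_size // 2) * dilation


def causal_conv1d(x: Sequence[float], kernel: Sequence[float],
                  dilation: int = 1) -> List[float]:
    """Hypothetical 1-D analogue: left-pad, then a VALID dilated cross-correlation."""
    k = len(kernel)
    padded = [0.0] * left_padding(k, dilation) + list(x)
    span = (k - 1) * dilation  # extent of the dilated kernel minus one
    out = []
    # VALID: only start positions where the whole dilated kernel fits.
    for start in range(len(padded) - span):
        acc = 0.0
        for j, w in enumerate(kernel):
            acc += w * padded[start + j * dilation]
        out.append(acc)
    return out


if __name__ == "__main__":
    x = [1.0, 2.0, 3.0, 4.0, 5.0]
    y = causal_conv1d(x, kernel=[1.0, 1.0, 1.0], dilation=2)
    print(y)  # [1.0, 2.0, 4.0, 6.0, 9.0], i.e. y[i] = x[i] + x[i-2] + x[i-4]
```

Putting all of the padding on the left (top) rather than splitting it as "SAME" does is what makes the result causal: changing the last input changes only the last output.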