base.py source code

python

Project: dataset  Author: analysiscenter
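import tensorflow as tf

# Method excerpt: it takes `cls`, so it presumably sits inside the model class as a classmethod
@classmethod
def _dynamic_crop(cls, inputs, static_shape, dynamic_shape, data_format='channels_last'):
    """Crop the spatial axes of `inputs` to `dynamic_shape`, keeping the leading corner."""
    # Spatial shape of the input (the third argument presumably requests a dynamic tensor)
    input_shape = cls.spatial_shape(inputs, data_format, True)
    n_channels = cls.num_channels(inputs, data_format)
    if data_format == 'channels_last':
        # Size -1 keeps the whole batch axis; channels are kept in full, only spatial axes are cropped
        slice_size = [(-1,), dynamic_shape, (n_channels,)]
        output_shape = [None] * (len(static_shape) + 1) + [n_channels]
    else:
        slice_size = [(-1, n_channels), dynamic_shape]
        output_shape = [None, n_channels] + [None] * len(static_shape)

    # Start the slice at index 0 along every axis
    begin = [0] * len(inputs.get_shape().as_list())
    size = tf.concat(slice_size, axis=0)
    # Slice only when the input's spatial shape differs from the target shape
    cond = tf.reduce_sum(tf.abs(input_shape - dynamic_shape)) > 0
    x = tf.cond(cond, lambda: tf.slice(inputs, begin=begin, size=size), lambda: inputs)
    # Restore the static rank and channel count that the dynamic slice leaves unknown
    x.set_shape(output_shape)
    return x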
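To make the slicing logic of `_dynamic_crop` easier to follow without a TensorFlow graph, here is a minimal sketch of the same idea on NumPy arrays. It is an illustration, not the project's implementation: NumPy stands in for TensorFlow, and the function name `dynamic_crop` and the parameter `target_spatial` are hypothetical. As in the original, the batch and channel axes are kept whole, every spatial axis is cut starting at index 0 (the equivalent of `tf.slice` with an all-zero `begin`), and the input is returned unchanged when its spatial shape already matches the target.

```python
import numpy as np


def dynamic_crop(inputs, target_spatial, data_format="channels_last"):
    """Hypothetical NumPy analogue of `_dynamic_crop`: crop spatial axes from index 0."""
    target_spatial = tuple(int(s) for s in target_spatial)
    # Spatial axes sit between batch and channels, or after channels
    if data_format == "channels_last":
        spatial = tuple(inputs.shape[1:-1])
    else:
        spatial = tuple(inputs.shape[2:])
    if len(spatial) != len(target_spatial):
        raise ValueError("target_spatial needs one entry per spatial axis")
    if any(t > s for t, s in zip(target_spatial, spatial)):
        raise ValueError("cannot crop to a size larger than the input")
    # Same shortcut as the tf.cond branch: skip slicing when shapes already match
    if spatial == target_spatial:
        return inputs
    crop = tuple(slice(0, t) for t in target_spatial)
    # Keep batch (and channels) whole, crop only the spatial axes
    if data_format == "channels_last":
        index = (slice(None),) + crop + (slice(None),)
    else:
        index = (slice(None), slice(None)) + crop
    return inputs[index]


x = np.zeros((2, 10, 12, 3))
print(dynamic_crop(x, (8, 9)).shape)  # (2, 8, 9, 3)
print(dynamic_crop(np.zeros((2, 3, 10, 12)), (8, 9), "channels_first").shape)  # (2, 3, 8, 9)
```

Unlike the TensorFlow version, the sketch works on concrete shapes, so it needs no `set_shape` step; the explicit size check also replaces the runtime error `tf.slice` would raise for an oversized crop.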