tensorflow_backend.py source code

Project: keras-contrib  Author: farizrahman4u
def extract_image_patches(x, ksizes, ssizes, padding='same',
                          data_format='channels_last'):
    '''
    Extract patches from a batch of images.

    # Parameters

        x : The input image tensor, 4-D, laid out according to data_format
        ksizes : 2-d tuple with the kernel size
        ssizes : 2-d tuple with the stride size
        padding : 'same' or 'valid'
        data_format : 'channels_last' or 'channels_first'

    # Returns
        The (k_w, k_h) patches extracted, where w and h index patch positions:
        channels_last  ==> (batch_size, w, h, k_w, k_h, c)
        channels_first ==> (batch_size, w, h, c, k_w, k_h)
    '''
    kernel = [1, ksizes[0], ksizes[1], 1]
    strides = [1, ssizes[0], ssizes[1], 1]
    padding = _preprocess_padding(padding)
    if data_format == 'channels_first':
        x = KTF.permute_dimensions(x, (0, 2, 3, 1))
    bs_i, w_i, h_i, ch_i = KTF.int_shape(x)
    patches = tf.extract_image_patches(x, kernel, strides, [1, 1, 1, 1],
                                       padding)
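    # Reshape to the Theano-style layout (batch, w, h, c, k_w, k_h).
    # The patch depth is ordered (k_w, k_h, c), so split off the channel axis,
    # move it ahead of the kernel axis, then unfold the kernel axis.
    bs, w, h, ch = KTF.int_shape(patches)
    patches = tf.reshape(patches, [-1, w, h, tf.floordiv(ch, ch_i), ch_i])
    patches = tf.transpose(patches, [0, 1, 2, 4, 3])
    patches = tf.reshape(patches, [-1, w, h, ch_i, ksizes[0], ksizes[1]])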
    if data_format == 'channels_last':
        patches = KTF.permute_dimensions(patches, [0, 1, 2, 4, 5, 3])
    return patches
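
The reshape and transpose sequence in the function above can be checked without TensorFlow. The following is a minimal NumPy sketch, not part of keras-contrib: `extract_patches_valid` and `to_patch_layout` are hypothetical helper names. It assumes 'valid' padding and assumes the flattened patch depth is ordered (kernel row, kernel column, channel), which is how the function above treats the output of `tf.extract_image_patches`.

```python
import numpy as np


def extract_patches_valid(x, ksizes, ssizes):
    """Hypothetical NumPy stand-in for patch extraction with 'valid' padding.

    x has shape (batch, rows, cols, channels). The result has shape
    (batch, out_rows, out_cols, k_rows * k_cols * channels), with the last
    axis flattened in (kernel row, kernel col, channel) order.
    """
    b, r, c, ch = x.shape
    kr, kc = ksizes
    sr, sc = ssizes
    out_r = (r - kr) // sr + 1
    out_c = (c - kc) // sc + 1
    out = np.empty((b, out_r, out_c, kr * kc * ch), dtype=x.dtype)
    for i in range(out_r):
        for j in range(out_c):
            window = x[:, i * sr:i * sr + kr, j * sc:j * sc + kc, :]
            out[:, i, j, :] = window.reshape(b, -1)
    return out


def to_patch_layout(flat, ksizes, channels, data_format='channels_last'):
    """Mirror the reshape/transpose steps of the function above in NumPy."""
    b, w, h, depth = flat.shape
    k_total = depth // channels
    # Split the flattened depth into (kernel positions, channels).
    p = flat.reshape(b, w, h, k_total, channels)
    # Move channels ahead of the kernel positions.
    p = p.transpose(0, 1, 2, 4, 3)
    # Unfold kernel positions: (batch, w, h, c, k_w, k_h).
    p = p.reshape(b, w, h, channels, ksizes[0], ksizes[1])
    if data_format == 'channels_last':
        # (batch, w, h, k_w, k_h, c)
        p = p.transpose(0, 1, 2, 4, 5, 3)
    return p


x = np.arange(2 * 4 * 4 * 3).reshape(2, 4, 4, 3)
flat = extract_patches_valid(x, (2, 2), (2, 2))
patches = to_patch_layout(flat, (2, 2), channels=3)
patches_cf = to_patch_layout(flat, (2, 2), channels=3,
                             data_format='channels_first')
```

With these inputs, `patches[b, i, j]` is the 2x2 window of `x[b]` starting at row `2 * i`, column `2 * j`, with channels last, while `patches_cf` keeps the channel axis ahead of the two kernel axes.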