tensorflow_backend.py source code

python
Project: SGAITagger  Author: zhiweiuu
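# Assumed module-level imports (Keras 1.x era); the excerpt does not show them.
import tensorflow as tf
from keras.backend import tensorflow_backend as KTF
from keras.backend.tensorflow_backend import _preprocess_border_mode


def extract_image_patches(X, ksizes, ssizes, border_mode="same", dim_ordering="tf"):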
    '''
    Extract patches from a batch of images.

    Parameters
    ----------
    X : 4-D input image tensor, (batch, rows, cols, channels) for 'tf'
        or (batch, channels, rows, cols) for 'th'
    ksizes : 2-tuple with the kernel (patch) size
    ssizes : 2-tuple with the strides
    border_mode : 'same' or 'valid'
    dim_ordering : 'tf' or 'th'

    Returns
    -------
    The (k_w, k_h) patches extracted, where w and h index the patch positions
    TF ==> (batch_size, w, h, k_w, k_h, c)
    TH ==> (batch_size, w, h, c, k_w, k_h)
    '''
    kernel = [1, ksizes[0], ksizes[1], 1]
    strides = [1, ssizes[0], ssizes[1], 1]
    padding = _preprocess_border_mode(border_mode)
    if dim_ordering == "th":
        X = KTF.permute_dimensions(X, (0, 2, 3, 1))
    ch_i = KTF.int_shape(X)[-1]  # channel count; X is channels-last at this point
    patches = tf.extract_image_patches(X, kernel, strides, [1, 1, 1, 1], padding)
    # Reshaping to fit Theano
    _, w, h, _ = KTF.int_shape(patches)
    # Use -1 for the batch axis so an unknown (None) batch size still reshapes
    patches = tf.reshape(tf.transpose(tf.reshape(patches, [-1, w, h, ksizes[0] * ksizes[1], ch_i]), [0, 1, 2, 4, 3]),
                         [-1, w, h, ch_i, ksizes[0], ksizes[1]])
    if dim_ordering == "tf":
        patches = KTF.permute_dimensions(patches, [0, 1, 2, 4, 5, 3])
    return patches
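
The reshape/transpose sequence in the function depends on how the flattened patch axis is ordered. The pure-Python sketch below mimics the 'valid' case for a single image, assuming the flattened axis is ordered (kernel row, kernel column, channel), which is how the function above treats the output of `tf.extract_image_patches`. The helper names `extract_patches_valid` and `unflatten_patch` are hypothetical and not part of the original file; this is an illustration of the layout, not the TensorFlow implementation.

```python
def extract_patches_valid(image, ksizes, ssizes):
    """Pure-Python stand-in for patch extraction with 'valid' padding.

    image  : nested list indexed as image[row][col][channel]
    returns: nested list out[i][j]; each entry is a flat list of length
             k_r * k_c * channels, ordered (kernel row, kernel col, channel)
    """
    k_r, k_c = ksizes
    s_r, s_c = ssizes
    rows, cols = len(image), len(image[0])
    # 'valid' padding: only positions where the kernel fits entirely
    out_rows = (rows - k_r) // s_r + 1
    out_cols = (cols - k_c) // s_c + 1
    out = []
    for i in range(out_rows):
        row_out = []
        for j in range(out_cols):
            flat = []
            for dr in range(k_r):
                for dc in range(k_c):
                    flat.extend(image[i * s_r + dr][j * s_c + dc])
            row_out.append(flat)
        out.append(row_out)
    return out


def unflatten_patch(flat, ksizes, channels):
    """Turn one flat patch into the 'tf' layout: patch[k_row][k_col][channel]."""
    k_r, k_c = ksizes
    return [[flat[(r * k_c + c) * channels:(r * k_c + c + 1) * channels]
             for c in range(k_c)]
            for r in range(k_r)]


# 4x4 single-channel image whose pixel values are 0..15
image = [[[r * 4 + c] for c in range(4)] for r in range(4)]
flat_patches = extract_patches_valid(image, ksizes=(2, 2), ssizes=(2, 2))
top_left = unflatten_patch(flat_patches[0][0], (2, 2), channels=1)
print(len(flat_patches), len(flat_patches[0]))  # patch positions per axis
print(top_left)
```

With a 2x2 kernel and stride 2 on a 4x4 image there are 2x2 patch positions, and the top-left patch holds pixels 0, 1, 4 and 5. This corresponds to the `(w, h, k_w, k_h, c)` part of the 'tf' return shape, without the batch axis.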