def extract_image_patches(X, ksizes, ssizes, border_mode="same", dim_ordering="tf"):
    """Extract sliding (k_w, k_h) patches from a batch of images.

    Wraps ``tf.extract_image_patches`` and reshapes its flattened patch
    vectors into explicit kernel and channel axes.

    Parameters
    ----------
    X : 4-D tensor
        The input images: (batch, w, h, c) for ``"tf"`` ordering,
        (batch, c, w, h) for ``"th"`` ordering.
    ksizes : sequence of 2 ints
        Patch (kernel) size along the two spatial axes.
    ssizes : sequence of 2 ints
        Stride between consecutive patches along the two spatial axes.
    border_mode : {'same', 'valid'}
        Padding mode, converted by ``_preprocess_border_mode``.
    dim_ordering : {'tf', 'th'}
        Layout of ``X`` and of the returned tensor.

    Returns
    -------
    6-D tensor of patches, where w and h are the number of patch positions
    along each spatial axis:
        TF ==> (batch_size, w, h, k_w, k_h, c)
        TH ==> (batch_size, w, h, c, k_w, k_h)

    Raises
    ------
    ValueError
        If ``dim_ordering`` is neither 'tf' nor 'th'.
    """
    if dim_ordering not in ("tf", "th"):
        raise ValueError("Unknown dim_ordering " + str(dim_ordering))

    def _shape_list(t):
        # Static dims where known, dynamic tf.shape() entries otherwise.
        # KTF.int_shape reports unknown dims (typically the batch size) as
        # None, and tf.reshape cannot accept None.
        static = KTF.int_shape(t)
        dynamic = tf.shape(t)
        return [dim if dim is not None else dynamic[i] for i, dim in enumerate(static)]

    k_w, k_h = ksizes[0], ksizes[1]
    kernel = [1, k_w, k_h, 1]
    strides = [1, ssizes[0], ssizes[1], 1]
    padding = _preprocess_border_mode(border_mode)
    if dim_ordering == "th":
        # tf.extract_image_patches expects NHWC input.
        X = KTF.permute_dimensions(X, (0, 2, 3, 1))
    channels = _shape_list(X)[3]
    patches = tf.extract_image_patches(X, kernel, strides, [1, 1, 1, 1], padding)
    batch, out_w, out_h, _ = _shape_list(patches)
    # The last axis of `patches` is flattened as (k_w, k_h, c), channel fastest.
    # Split it, then move channels in front of the kernel axes (Theano layout).
    patches = tf.reshape(patches, [batch, out_w, out_h, k_w * k_h, channels])
    patches = tf.transpose(patches, [0, 1, 2, 4, 3])
    patches = tf.reshape(patches, [batch, out_w, out_h, channels, k_w, k_h])
    if dim_ordering == "tf":
        # (batch, w, h, c, k_w, k_h) -> (batch, w, h, k_w, k_h, c)
        patches = KTF.permute_dimensions(patches, [0, 1, 2, 4, 5, 3])
    return patches
评论列表
文章目录