def extract_patches_fn(image: tf.Tensor, patch_shape: list, offsets) -> tf.Tensor:
    """
    Cut an image into half-overlapping patches, starting from a shifted origin.

    The image is cropped at its top/left edges by an offset derived from
    ``offsets``, then tiled with ``h`` x ``w`` windows whose stride is half the
    window size. 'VALID' padding is used, so only windows that fit entirely
    inside the cropped image are returned.

    :param image: tf.Tensor indexed here as [height, width, channels]; its last
        dimension must be statically known because it is converted with ``int``
    :param patch_shape: [h, w], height and width of every patch
    :param offsets: pair (vertical, horizontal) of values between 0 and 1 that
        control how far the patch grid is shifted from the top-left corner
    :return: patches [batch_patches, h, w, c]
    """
    with tf.name_scope('patch_extraction'):
        patch_h, patch_w = patch_shape
        channels = image.get_shape()[-1]

        # Neighbouring patches start half a patch apart (integer division).
        stride_h = patch_h // 2
        stride_w = patch_w // 2

        # `*` and `//` bind left to right, so each shift is (offset * size) // 2.
        # The floor division already yields a whole number, which leaves
        # tf.round with nothing to change before the cast to int32.
        shift_rows = tf.cast(tf.round(offsets[0] * patch_h // 2), dtype=tf.int32)
        shift_cols = tf.cast(tf.round(offsets[1] * patch_w // 2), dtype=tf.int32)

        # Drop the leading rows/columns and prepend the batch axis that
        # extract_image_patches requires, in a single slicing step.
        shifted = image[tf.newaxis, shift_rows:, shift_cols:, :]

        raw = tf.extract_image_patches(
            shifted,
            ksizes=[1, patch_h, patch_w, 1],
            strides=[1, stride_h, stride_w, 1],
            rates=[1, 1, 1, 1],
            padding='VALID',
        )

        # `raw` is [1, rows, cols, h * w * c]; collapse the first three axes
        # into one patch axis and unfold the flattened pixels of each patch.
        grid = tf.shape(raw)
        patch_count = tf.reduce_prod(grid[0:3])
        return tf.reshape(raw, [patch_count, patch_h, patch_w, int(channels)])
评论列表
文章目录