def testRandomDistort(self):
    """Check that random_distortion preserves image shape and box coordinates.

    Builds a (600, 900, 3) image with five requested boxes, appends a
    constant label column (value 3) to the boxes, and passes everything
    through ``self._random_distort``. The returned image must keep the input
    shape, and the first four columns of the returned boxes must equal the
    original, unlabeled boxes.
    """
    input_shape = (600, 900, 3)
    distort_config = self._random_distort_config
    num_boxes = 5
    box_label = 3

    img, boxes = self._get_image_with_boxes(input_shape, num_boxes)

    # Same label for every box, appended as one extra trailing column.
    label_column = tf.fill((boxes.shape[0], 1), box_label)
    labeled_boxes = tf.concat([boxes, label_column], axis=1)

    out_image, out_boxes = self._random_distort(
        img, distort_config, labeled_boxes)

    # The distorted image must keep the input shape.
    self.assertEqual(input_shape, out_image.shape)
    # Box coordinates (label column sliced off) must come back unchanged.
    self.assertAllEqual(boxes, out_boxes[:, :4])