def bbox_to_mask(bbox, region_size, output_size, dtype=tf.float32):
    """Rasterize every bounding box in `bbox` into a fixed-size binary mask.

    For each box, a mask of size `region_size` is built with ones inside the
    rectangle given by the box and zeros elsewhere. That mask is then resized
    to `output_size` with bilinear interpolation. The per-box work is
    delegated to `_bbox_to_mask_fixed_size`, which is applied to every
    (box, region size) pair via `tf.map_fn`.

    :param bbox: tensor of shape (..., 4)
    :param region_size: tensor of shape (..., 2); once flattened it must hold
        exactly one entry per box in `bbox`, since both are iterated together
    :param output_size: 2-tuple of ints, the spatial size of each output mask
    :param dtype: tf.dtype of the produced masks
    :return: a tensor of shape (..., output_size[0], output_size[1]) whose
        leading dimensions match those of `bbox`
    """
    # Capture the leading (batch-like) dimensions of `bbox` before flattening,
    # so the stacked per-box masks can be reshaped back at the end.
    leading_dims = tf.shape(bbox)[:-1]
    final_shape = tf.concat(values=[leading_dims, output_size], axis=0)

    # Collapse all leading dimensions so map_fn iterates over individual boxes.
    flat_boxes = tf.reshape(bbox, [-1, 4])
    flat_sizes = tf.reshape(region_size, [-1, 2])

    def _single_mask(box_and_size):
        box, size = box_and_size
        return _bbox_to_mask_fixed_size(box, size, output_size, dtype)

    flat_masks = tf.map_fn(_single_mask, (flat_boxes, flat_sizes), dtype=dtype)
    return tf.reshape(flat_masks, final_shape)
评论列表
文章目录