def calc_center_bb(binary_class_mask, default_center=(160.0, 160.0),
                   default_crop_size=100.0):
    """Return the bounding-box center, bounding box and crop size per mask.

    Note: the center is the midpoint of the foreground bounding box,
    not the center of mass of the foreground pixels.

    Args:
        binary_class_mask: Tensor of shape [batch, dim1, dim2] or
            [batch, dim1, dim2, 1]. The batch and spatial sizes must be
            statically known, because the graph is built with a Python loop
            over the batch. Values are cast to int32 and pixels equal to 1
            are treated as foreground.
        default_center: (x, y) pair used for a sample whose computed center
            is not finite (e.g. a mask with no foreground pixels).
            Defaults to the previously hard-coded (160.0, 160.0).
        default_crop_size: Crop size used for a sample whose computed crop
            size is not finite. Defaults to the previously hard-coded 100.0.

    Returns:
        A tuple ``(center, bb, crop_size)`` of float32 tensors, where
        x is the index along axis 1 and y along axis 2 of the squeezed mask:
          center:    [batch, 2]; (x, y) midpoint of the bounding box.
          bb:        [batch, 2, 2]; bb[b] == [[x_min, x_max], [y_min, y_max]].
                     Not sanitized: for a mask with no foreground pixels
                     the min/max reductions run over an empty tensor and
                     the entries are non-finite (presumably +/-inf).
          crop_size: [batch, 1]; max(x_max - x_min, y_max - y_min).

    Raises:
        AssertionError: if the mask is not rank 3 after squeezing, or if
            the batch dimension is not smaller than both spatial dimensions.
    """
    with tf.variable_scope('calc_center_bb'):
        binary_class_mask = tf.cast(binary_class_mask, tf.int32)
        binary_class_mask = tf.equal(binary_class_mask, 1)
        s = binary_class_mask.get_shape().as_list()
        if len(s) == 4:
            binary_class_mask = tf.squeeze(binary_class_mask, [3])

        s = binary_class_mask.get_shape().as_list()
        assert len(s) == 3, "binary_class_mask must be 3D."
        # Heuristic guard that the batch axis comes first: it assumes the
        # batch is always smaller than either spatial dimension.
        assert (s[0] < s[1]) and (s[0] < s[2]), \
            "binary_class_mask must be [Batch, Width, Height]"

        # Coordinate grids ("meshgrid"): X[i, j] == i, Y[i, j] == j.
        # Cast to float32 once here rather than once per sample in the loop.
        x_range = tf.expand_dims(tf.range(s[1]), 1)
        y_range = tf.expand_dims(tf.range(s[2]), 0)
        X = tf.cast(tf.tile(x_range, [1, s[2]]), tf.float32)
        Y = tf.cast(tf.tile(y_range, [s[1], 1]), tf.float32)

        bb_list = []
        center_list = []
        crop_size_list = []
        for i in range(s[0]):
            sample_mask = binary_class_mask[i, :, :]
            X_masked = tf.boolean_mask(X, sample_mask)
            Y_masked = tf.boolean_mask(Y, sample_mask)
            x_min = tf.reduce_min(X_masked)
            x_max = tf.reduce_max(X_masked)
            y_min = tf.reduce_min(Y_masked)
            y_max = tf.reduce_max(Y_masked)

            # Stacking [start, end] along axis 1 yields
            # [[x_min, x_max], [y_min, y_max]].
            start = tf.stack([x_min, y_min])
            end = tf.stack([x_max, y_max])
            bb_list.append(tf.stack([start, end], 1))

            # Midpoint of the box; replaced by the default when it is not
            # finite (e.g. no foreground pixels in this sample).
            raw_center = tf.stack([0.5 * (x_max + x_min),
                                   0.5 * (y_max + y_min)], 0)
            center = tf.cond(
                tf.reduce_all(tf.is_finite(raw_center)),
                lambda: raw_center,
                lambda: tf.constant(default_center, dtype=tf.float32))
            center.set_shape([2])
            center_list.append(center)

            # Side of a square crop covering the box (larger extent).
            raw_crop_size = tf.expand_dims(
                tf.maximum(x_max - x_min, y_max - y_min), 0)
            crop_size = tf.cond(
                tf.reduce_all(tf.is_finite(raw_crop_size)),
                lambda: raw_crop_size,
                lambda: tf.constant([default_crop_size], dtype=tf.float32))
            crop_size.set_shape([1])
            crop_size_list.append(crop_size)

        bb = tf.stack(bb_list)
        center = tf.stack(center_list)
        crop_size = tf.stack(crop_size_list)
        return center, bb, crop_size
评论列表
文章目录