def arg_closest_anchor(bboxes, anchors):
    """Return, for every box, the index of its closest anchor.

    Closeness is the squared Euclidean distance between the four box
    coordinates. Box format: [ymin, xmin, ymax, xmax].

    Args:
        bboxes: 2-D tensor with one box per row; only the first four
            columns take part in the distance.
        anchors: 2-D tensor with one anchor per row, in the same
            coordinate format and dtype as ``bboxes``.

    Returns:
        A 1-D tensor with one entry per row of ``bboxes`` holding the row
        index into ``anchors`` of the nearest anchor (``tf.argmin`` output,
        int64 by default). Ties resolve to the lowest index.
    """
    # Broadcast [num_bboxes, 1, 4] against [1, num_anchors, 4] so every
    # (box, anchor) pair is compared without materialising repeated copies
    # by hand; this also works when the anchor count is not known
    # statically.
    boxes = tf.expand_dims(bboxes[:, :4], axis=1)
    candidates = tf.expand_dims(anchors[:, :4], axis=0)

    # Sum the squared coordinate differences over the last axis to get a
    # [num_bboxes, num_anchors] matrix of squared distances.
    square_dist = tf.reduce_sum(tf.square(boxes - candidates), axis=2)

    # tf.argmin(axis=...) replaces the deprecated tf.arg_min(dimension=...).
    return tf.argmin(square_dist, axis=1)
评论列表
文章目录