def _generate_labels(self, overlaps, positive_overlap=0.5, negative_overlap=0.3,
                     clobber_positives=True):
    """Build the graph ops that assign a label to every row of ``overlaps``.

    Each row receives one float32 label: ``1`` (positive), ``0`` (negative)
    or ``-1`` (left unassigned). Three rules are applied:

    * the row holding the maximum of each column is set to ``1``;
    * rows whose row-wise maximum is ``> positive_overlap`` are set to ``1``;
    * rows whose row-wise maximum is ``< negative_overlap`` are set to ``0``.

    Args:
        overlaps: 2-D float tensor. Presumably rows are anchors, columns are
            ground-truth boxes and entries are IoU values -- confirm against
            the caller.
        positive_overlap: Row maximum above which a row is labelled ``1``.
        negative_overlap: Row maximum below which a row is labelled ``0``.
        clobber_positives: When true (the default, and the behaviour this
            method always had) the negative rule is written last, so it can
            overwrite a row that was labelled ``1`` only for being a column's
            best match. When false the negative rule is written first and
            positives win.

    Returns:
        The tensor returned by the last ``tf.scatter_update`` on a
        non-trainable float32 variable with ``tf.shape(overlaps)[0]`` entries.
    """
    # Every entry starts unassigned (-1). The length comes from tf.shape(),
    # so the variable is created with validate_shape=False.
    labels = tf.Variable(
        tf.ones(shape=(tf.shape(overlaps)[0],), dtype=tf.float32) * -1,
        trainable=False,
        validate_shape=False,
    )

    # Index of the best row for each column.
    gt_max_overlaps = tf.argmax(overlaps, axis=0)
    # Best value in each row; a direct reduction replaces the former
    # arg_max -> one_hot -> boolean_mask round trip.
    max_overlaps = tf.reduce_max(overlaps, axis=1)
    if self._debug:
        max_overlaps = tf.Print(max_overlaps, [max_overlaps])

    over_threshold_mask = tf.reshape(tf.where(max_overlaps > positive_overlap), (-1,))
    if self._debug:
        over_threshold_mask = tf.Print(
            over_threshold_mask, [over_threshold_mask], message='over threshold index : ')

    below_threshold_mask = tf.reshape(tf.where(max_overlaps < negative_overlap), (-1,))
    if self._debug:
        below_threshold_mask = tf.Print(
            below_threshold_mask, [below_threshold_mask], message='below threshold index : ')

    def _write(current, indices, make_values):
        # One constant per index; chaining on the returned tensor fixes the
        # order in which the updates run.
        return tf.scatter_update(current, indices, make_values((tf.shape(indices)[0],)))

    if not clobber_positives:
        labels = _write(labels, below_threshold_mask, tf.zeros)
    labels = _write(labels, gt_max_overlaps, tf.ones)
    labels = _write(labels, over_threshold_mask, tf.ones)
    if clobber_positives:
        labels = _write(labels, below_threshold_mask, tf.zeros)
    return labels
# (Web-page residue, not code: "Comment list" / "Table of contents".)