def compute_detections_greedy(seg_preds, boxes_preds, num_outputs,
                              seg_threshold=0.2,
                              sigma=5e-3, step=0.2, num_iters=20,
                              dist_threshold=20.0):
    """Greedily pick ``num_outputs`` boxes from dense per-pixel predictions.

    Every element of ``seg_preds[:, :, 1]`` that exceeds ``seg_threshold``
    contributes the matching row of ``boxes_preds`` (reshaped to ``[-1, 4]``)
    as a candidate. The candidates are refined with ``refine_boxes``. Then a
    ``tf.scan`` runs ``num_outputs`` greedy steps. Each step scores every
    candidate as ``presence[i] * sum_j exp(-dist[i, j])`` and selects the
    highest-scoring one. It then marks as no longer present every candidate
    whose distance to the selected box is ``<= dist_threshold``.

    Args:
      seg_preds: Tensor indexed as ``[:, :, 1]``; presumably ``(H, W, C)``
        segmentation scores with channel 1 = foreground (TODO confirm).
      boxes_preds: Tensor reshaped to ``[-1, 4]``; presumably ``(H, W, 4)``
        per-pixel boxes aligned with ``seg_preds`` (TODO confirm).
      num_outputs: Number of greedy steps, i.e. number of boxes returned.
      seg_threshold: Strict lower bound on ``seg_preds[:, :, 1]`` for a pixel
        to yield a candidate.
      sigma: Forwarded to ``refine_boxes``. Boxes are also divided by it
        before computing pairwise distances.
      step: Forwarded to ``refine_boxes``.
      num_iters: Forwarded to ``refine_boxes``.
      dist_threshold: Suppression radius. It is measured in whatever units
        ``nnutil.pairwise_distance`` returns for ``boxes / sigma``.

    Returns:
      ``(selected_boxes, confidence)``:
      - ``selected_boxes`` has shape ``[num_outputs, 4]``.
      - ``confidence`` has shape ``[num_outputs]`` and holds the score of
        each pick.

    NOTE(review): once every candidate has been suppressed, all scores are 0.
    The remaining steps then report confidence 0.0 for an index chosen by
    ``tf.argmax`` on a tie, which TF does not specify. With no candidate
    pixels at all, the argmax/gather over an empty tensor presumably fails
    at run time.
    """
    mask_flat = tf.reshape(seg_preds[:, :, 1], [-1])
    boxes_flat = tf.reshape(boxes_preds, [-1, 4])
    # TODO: also collect (y,x) coordinates
    idxs = tf.where(mask_flat > seg_threshold)[:, 0]
    boxes = tf.gather(boxes_flat, idxs)
    # The second output of refine_boxes is intentionally discarded. The
    # confidences returned below are recomputed by the greedy selection.
    boxes, _ = refine_boxes(boxes, num_iters, step, sigma)
    num_boxes = tf.shape(boxes)[0]
    # relu clamps any negative distances, presumably numerical noise
    # (TODO confirm), so every weight exp(-d) lies in (0, 1].
    dists = tf.nn.relu(nnutil.pairwise_distance(boxes / sigma))
    weights = tf.exp(-dists)

    def _next_detection(prev, _elem):
        """One greedy step: pick the best present box, suppress its neighbours."""
        _, _, presence = prev
        # presence is [num_boxes, 1] and broadcasts across each row, so
        # score[i] = presence[i] * sum_j weights[i, j]; shape [num_boxes, 1].
        score = tf.reduce_sum(weights * presence, [1], True)
        idx = tf.cast(tf.argmax(score, 0)[0], tf.int32)
        # Keep only boxes strictly farther than dist_threshold from the pick
        # (this also drops the pick itself when its self-distance is 0).
        keep = tf.cast(tf.gather(dists, idx) > dist_threshold, tf.float32)
        presence = presence * keep[:, tf.newaxis]
        return idx, tf.gather(score, idx)[0], presence

    # The initializer's dtypes (int32, float32, float32) match what
    # _next_detection returns, as tf.scan requires.
    selected, confidence, _ = tf.scan(
        _next_detection,
        tf.range(0, num_outputs),
        initializer=(0, 0.0, tf.ones([num_boxes, 1])))
    return tf.gather(boxes, selected), confidence
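# 评论列表 ("comment list") - navigation text left over from the source web page
# 文章目录 ("article table of contents") - navigation text left over from the source web page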