def refine_boxes(boxes, num_iters, step, sigma):
    """Iteratively blend every box with an affinity-weighted average of all boxes.

    Each step scales the current boxes by ``1 / sigma``, turns their pairwise
    distances into affinities with ``exp(-distance)``, normalises each row of
    affinities to sum to one, and moves every box a fraction ``step`` of the
    way toward the resulting weighted average.

    Args:
        boxes: Tensor of boxes to refine. Assumed to be rank 2, one box per
            row (the ``[:, 0:1]`` slice, the axis-1 reduction and the matmul
            suggest this) -- TODO confirm against callers.
        num_iters: Number of refinement steps to run; must be greater than 1.
        step: Blend factor; 0 keeps the previous boxes, 1 replaces them with
            the weighted average.
        sigma: Divisor applied to the boxes before distances are computed.

    Returns:
        A pair ``(refined_boxes, confidence)`` taken from the last step, where
        ``confidence`` is the per-row sum of the un-normalised affinities.
    """
    assert num_iters > 1

    def _refine_once(carry, _step_index):
        # Only the box state is carried forward; the confidence slot of the
        # previous step is discarded and recomputed.
        current, _stale_confidence = carry
        scaled = current / sigma
        # relu clamps any negative values returned by the distance helper to 0.
        distances = tf.nn.relu(nnutil.pairwise_distance(scaled))
        affinity = tf.exp(-distances)
        row_mass = tf.reduce_sum(affinity, [1], True)
        averaged = tf.matmul(affinity / row_mass, current)
        blended = (1.0 - step) * current + step * averaged
        return blended, row_mass

    # The second element only gives the scan carry its structure for the
    # confidence output; its values are never read.
    initial_carry = (boxes, boxes[:, 0:1])
    history = tf.scan(_refine_once,
                      tf.range(0, num_iters),
                      initializer=initial_carry)
    box_history, confidence_history = history[0], history[1]
    return box_history[-1], confidence_history[-1]
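# Web-page navigation residue from the scraped source ("Comments", "Table of contents"); not part of the code.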