def ctc_label_dense_to_sparse(labels, label_lengths, init_len):
    """Split a dense, zero-padded label matrix into sparse components.

    Every position at or beyond ``label_lengths[row]`` is zeroed by a
    sequence mask. The remaining strictly positive entries are then gathered
    as the sparse values together with their (row, col) indices.

    NOTE: every non-zero label in a row of ``labels`` must lie within the
    first ``label_lengths[row]`` positions. Anything past that is silently
    dropped by the mask.

    Args:
        labels: 2-D tensor of dense labels. Presumably (batch, max_label_len)
            integer ids with 0 used as padding; TODO confirm with callers.
        label_lengths: 1-D tensor with one valid length per row of ``labels``.
        init_len: tensor whose maximum is used as the mask width. The mask is
            multiplied elementwise with ``labels``, so that maximum must match
            (be broadcast-compatible with) the second dimension of ``labels``.

    Returns:
        A tuple ``(indices, values, dense_shape, cur_len, mask)``:
          indices: int64 (num_nonzero, 2) coordinates of positive labels.
          values: the positive label values at ``indices``.
          dense_shape: ``tf.shape(labels)`` of the masked labels.
          cur_len: float32 vector of length batch, every entry 2.0. This is
              a fixed value; its meaning is not evident from this code.
          mask: int32 (batch, reduce_max(init_len)) sequence mask.

    Raises:
        ValueError: if the statically known batch sizes of ``labels`` and
            ``label_lengths`` differ.
    """
    static_batch = labels.get_shape().as_list()[0]
    lengths_batch = label_lengths.get_shape().as_list()[0]
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    # Only compare when both sizes are known at graph-construction time.
    if (
        static_batch is not None
        and lengths_batch is not None
        and static_batch != lengths_batch
    ):
        raise ValueError(
            "labels batch size (%d) does not match label_lengths size (%d)"
            % (static_batch, lengths_batch)
        )
    # Fall back to the runtime batch size when it is not statically known.
    batch = static_batch if static_batch is not None else tf.shape(labels)[0]

    max_len = tf.reduce_max(init_len)
    cur_len = tf.fill([batch], 2.0)

    mask = tf.cast(tf.sequence_mask(label_lengths, max_len), tf.int32)
    # Cast to the labels' dtype for the multiply so int64 labels also work;
    # the returned mask stays int32 as before.
    labels = tf.multiply(labels, tf.cast(mask, labels.dtype))

    # Keep only strictly positive labels (0 is treated as padding/removed).
    indices = tf.where(tf.greater(labels, 0))
    vals_sparse = tf.gather_nd(labels, indices)
    return indices, vals_sparse, tf.shape(labels), cur_len, mask
# Example source code for Python arg_min()
def arg_closest_anchor(bboxes, anchors):
    """Find the closest anchor for every box. Box format [ymin, xmin, ymax, xmax].

    Closeness is the squared Euclidean distance over the first four
    coordinates of a box and an anchor. Any extra columns are ignored.

    Args:
        bboxes: 2-D tensor, one box per row (only columns 0-3 are used).
        anchors: 2-D tensor of the same dtype, one anchor per row (only
            columns 0-3 are used). The anchor count no longer needs to be
            known statically.

    Returns:
        1-D int64 tensor of length num_bboxes. Entry i is the row index into
        ``anchors`` of the anchor closest to box i. Tie-breaking follows
        ``tf.argmin``.
    """
    # Broadcast (N, 1, 4) against (1, A, 4) to form every box/anchor pair,
    # instead of materialising repeated and tiled copies of both inputs.
    box_coords = tf.expand_dims(bboxes[:, :4], axis=1)
    anchor_coords = tf.expand_dims(anchors[:, :4], axis=0)
    squared = tf.square(box_coords - anchor_coords)  # (N, A, 4)
    # Sum the coordinate terms in the same left-to-right order as the
    # original per-coordinate expression, so float results are unchanged.
    square_dist = squared[..., 0] + squared[..., 1] + squared[..., 2] + squared[..., 3]
    # tf.argmin replaces the deprecated tf.arg_min, which is gone in TF 2.x.
    return tf.argmin(square_dist, axis=1)