sparse_trans_test.py source code

python
Project: TF-Speech-Recognition  Author: ZhishengWang
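import tensorflow as tf


def ctc_label_dense_to_sparse(labels, label_lengths, init_len):
    """Split a padded int32 label matrix into the parts of a SparseTensor.

    TODO: the non-zero entries in every row of 'labels' must fall within the
    first label_lengths[i] positions; anything beyond that is masked out.
    """
    # Static shapes: labels is [batch, width], label_lengths is [batch].
    label_shape = labels.get_shape().as_list()
    len_shape = label_lengths.get_shape().as_list()[0]
    batch = label_shape[0]
    assert batch == len_shape, "labels and label_lengths must share the batch size"
    # Mask width. tf.multiply below only succeeds if this equals labels' width.
    max_len = tf.reduce_max(init_len)

    # A constant 2.0 for every batch entry, returned unchanged.
    cur_len = tf.constant(2.0)
    cur_len = tf.tile(tf.expand_dims(cur_len, axis=-1), [batch])

    # 1 where position < label_lengths[i], else 0; cast to match int32 labels.
    mask = tf.cast(tf.sequence_mask(label_lengths, max_len), tf.int32)
    labels = tf.multiply(labels, mask)

    # Keep strictly positive labels only; 0 doubles as padding, so class 0 is dropped.
    where_val = tf.less(tf.constant(0), labels)
    indices = tf.where(where_val)  # int64, shape [num_nonzero, 2], row-major order

    vals_sparse = tf.gather_nd(labels, indices)
    # tf.shape returns int32; cast to int64 before building a tf.SparseTensor.
    return indices, vals_sparse, tf.shape(labels), cur_len, mask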
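To check what the function returns without running TensorFlow, here is a framework-free sketch of the same transformation in plain Python. The helper name `dense_to_sparse_reference` is hypothetical and not part of TF-Speech-Recognition. It mirrors the masking step (`tf.sequence_mask` followed by `tf.multiply`) and the extraction step (`tf.where` followed by `tf.gather_nd`), assumes, like the TensorFlow code, that every row of `labels` is exactly `max_len` wide, and leaves out the constant `cur_len` output.

```python
from typing import List, Sequence, Tuple


def dense_to_sparse_reference(
    labels: Sequence[Sequence[int]],
    label_lengths: Sequence[int],
    max_len: int,
) -> Tuple[List[Tuple[int, int]], List[int], Tuple[int, int], List[List[int]]]:
    """Hypothetical pure-Python mirror of ctc_label_dense_to_sparse.

    Zeroes positions >= label_lengths[i] (like tf.sequence_mask + tf.multiply),
    then collects (row, col) and value of every remaining strictly positive
    entry in row-major order (like tf.where + tf.gather_nd).
    """
    if len(labels) != len(label_lengths):
        raise ValueError("labels and label_lengths must have the same batch size")
    # mask[i][j] == 1 exactly when j < label_lengths[i]
    mask = [[1 if j < n else 0 for j in range(max_len)] for n in label_lengths]
    indices: List[Tuple[int, int]] = []
    values: List[int] = []
    for i, row in enumerate(labels):
        if len(row) != max_len:
            # The TF version would fail to multiply mismatched widths as well.
            raise ValueError("each row of labels must have length max_len")
        for j, v in enumerate(row):
            masked = v * mask[i][j]
            if masked > 0:  # 0 (padding) and negatives are dropped
                indices.append((i, j))
                values.append(masked)
    dense_shape = (len(labels), max_len)
    return indices, values, dense_shape, mask
```

With `labels = [[3, 5, 0, 0], [1, 2, 4, 7]]`, `label_lengths = [2, 3]` and `max_len = 4`, the trailing `7` falls outside the mask and is dropped, leaving indices `[(0, 0), (0, 1), (1, 0), (1, 1), (1, 2)]` and values `[3, 5, 1, 2, 4]`. Because only strictly positive entries survive, a label id of 0 never reaches the output, which matches the TensorFlow version.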