sparsetensor.py source code

python

Project: tensorflow_end2end_speech_recognition  Author: hirofumi0810
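import numpy as np
import tensorflow as tf


def sparsetensor2list(labels_st, batch_size):
    """Convert labels from a sparse tensor to a list.
    Args:
        labels_st: A tf.SparseTensorValue of labels, or a list
            [indices, values, shape] holding the same data
        batch_size (int): the size of the mini-batch
    Returns:
        labels (list): list of np.ndarray of size `[B]`. Each element is the
            sequence of target labels for one input.
    """
    if isinstance(labels_st, tf.SparseTensorValue):
        # Evaluated SparseTensor returned by a TensorFlow 1.x session
        # (available as tf.compat.v1.SparseTensorValue under TensorFlow 2.x)
        indices = labels_st.indices
        values = labels_st.values
    else:
        # labels_st is expected to be a list [indices, values, shape];
        # convert to arrays so the slicing below also works on plain lists
        indices = np.asarray(labels_st[0])
        values = np.asarray(labels_st[1])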

    if batch_size == 1:
        return values.reshape((1, -1))

    labels = []
    # Assumes entries are in row-major order, so each utterance starts at
    # the entry whose column (time-step) index is 0
    batch_boundary = np.where(indices[:, 1] == 0)[0]

    # TODO: This breaks when a CTC model outputs no labels for some utterance:
    # that utterance has no entry with column index 0, so batch_boundary has
    # fewer than batch_size elements and the loop below raises an IndexError.

    for i in range(batch_size - 1):
        label_each_utt = values[batch_boundary[i]:batch_boundary[i + 1]]
        labels.append(label_each_utt)
    # Labels of the last utterance run to the end of values
    labels.append(values[batch_boundary[-1]:])

    return labels
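The TODO in the function points at a real limitation: utterances are split at entries whose column index is 0, so an utterance for which the CTC decoder emits no labels contributes no boundary, and with `batch_size >= 2` the loop then indexes past the end of `batch_boundary`. A minimal sketch of an alternative, assuming the same `[indices, values]` layout with `indices` of shape `[N, 2]` holding (row, column) pairs, groups entries by their row index instead. The name `sparse_labels_to_list` is hypothetical and is not part of the project.

```python
import numpy as np


def sparse_labels_to_list(indices, values, batch_size):
    """Split a 2-D sparse label matrix into one label array per utterance.

    Hypothetical helper, not part of the original project. Entries are grouped
    by their row index (indices[:, 0]) instead of by the positions where the
    column index is 0, so an utterance with no labels yields an empty array
    instead of breaking the boundaries of the utterances after it.
    """
    # Accept lists as well as arrays; reshape keeps an empty input as (0, 2)
    indices = np.asarray(indices).reshape(-1, 2)
    values = np.asarray(values)
    # Sort entries by (row, column) so each sequence comes out in time order
    order = np.lexsort((indices[:, 1], indices[:, 0]))
    indices, values = indices[order], values[order]
    rows = indices[:, 0]
    # One (possibly empty) array per row, including rows with no entries
    return [values[rows == i] for i in range(batch_size)]


# Utterance 1 has no labels, e.g. the CTC decoder emitted nothing for it
indices = [[0, 0], [0, 1], [2, 0], [2, 1], [2, 2]]
values = [3, 5, 7, 1, 4]
labels = sparse_labels_to_list(indices, values, batch_size=3)
print([seq.tolist() for seq in labels])  # [[3, 5], [], [7, 1, 4]]
```

On the same input, the original boundary search finds only two starts (positions 0 and 2) for three utterances, which is exactly the case the TODO describes. The per-row boolean mask rescans all entries once per utterance; for typical mini-batch sizes that cost is small, and a sort followed by `np.searchsorted` would be one way to avoid it for very large batches.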