def sparsetensor2list(labels_st, batch_size):
    """Convert labels from sparse tensor to list.

    Args:
        labels_st: A SparseTensor of labels. Either an object exposing
            ``indices`` and ``values`` attributes (such as the
            ``SparseTensorValue`` produced by TensorFlow) or a sequence
            ``[indices, values, shape]``. Each row of ``indices`` is the
            coordinate of the matching entry of ``values``; its first
            component is the position of the input in the mini-batch.
        batch_size (int): the size of mini-batch
    Returns:
        labels: If ``batch_size`` is 1, an np.ndarray of shape ``[1, L]``
            containing every value (kept for backward compatibility).
            Otherwise a list of ``batch_size`` np.ndarray; element ``i`` is
            the sequence of target labels of the ``i``-th input, in the
            order they appear in ``values``, or an empty array when that
            input has no labels (e.g. a CTC model emitted nothing for it).
    """
    # Duck-type the input: a SparseTensorValue exposes ``indices`` /
    # ``values`` attributes, while the plain list form is positional. This
    # avoids referencing ``tf.SparseTensorValue``, which is not exported at
    # the top level of TF 2.x (only via ``tf.compat.v1``).
    if hasattr(labels_st, "indices") and hasattr(labels_st, "values"):
        indices, values = labels_st.indices, labels_st.values
    else:
        indices, values = labels_st[0], labels_st[1]
    # Accept plain Python lists as well as ndarrays.
    indices = np.asarray(indices)
    values = np.asarray(values)

    if batch_size == 1:
        # Backward-compatible special case: a [1, L] array, not a list.
        return values.reshape((1, -1))

    # Group by the row (batch) coordinate rather than searching for
    # column-0 entries: an input with no labels has no column-0 entry, which
    # used to shift every later boundary or raise IndexError.
    # np.asarray([]) is 1-D, so an empty index list needs a guard.
    if indices.ndim == 2:
        rows = indices[:, 0]
    else:
        rows = np.empty(0, dtype=np.int64)
    return [values[rows == i] for i in range(batch_size)]
sparsetensor.py 文件源码
python
阅读 26
收藏 0
点赞 0
评论 0
评论列表
文章目录