def _get_labels_feed_item(label_list, max_time=None):
    """
    Generate the tensor from `label_list` to feed as labels into the network.

    Each label sequence becomes one row of a 2-D sparse tensor: element ``j``
    of sequence ``i`` is stored at index ``[i, j]`` with value ``label[j]``.

    Args:
        label_list: a list of encoded labels, each a sequence of ints.
        max_time: the maximum time length of `label_list`; used as the second
            dimension of the sparse tensor's dense shape. If None, the length
            of the longest label sequence is used.
    Returns:
        the SparseTensorValue to feed into the network
    """
    if max_time is None:
        max_time = max((len(label) for label in label_list), default=0)

    # `np.int` was deprecated in NumPy 1.20 and removed in 1.24; use an
    # explicit 64-bit integer type instead.
    label_shape = np.array([len(label_list), max_time], dtype=np.int64)

    label_indices = []
    label_values = []
    for label_idx, label in enumerate(label_list):
        for id_idx, identifier in enumerate(label):
            label_indices.append([label_idx, id_idx])
            label_values.append(identifier)

    # reshape keeps the indices 2-D (shape [0, 2]) when there are no label
    # ids at all; np.array([]) alone would be 1-D with shape (0,).
    label_indices = np.array(label_indices, dtype=np.int64).reshape(-1, 2)
    label_values = np.array(label_values, dtype=np.int64)
    return tf.SparseTensorValue(label_indices, label_values, label_shape)
评论列表
文章目录