def label_metadata(label_matrix, label_col):
    """Encode one label column of a DataFrame as consecutive integer ids.

    Args:
        label_matrix: pandas DataFrame that contains the label column.
        label_col: The label column, given either as a column name or as a
            positional index (an int, or a numeric string such as ``"2"``).

    Returns:
        A tuple ``(mapped_labels, label_map)``:

        * ``mapped_labels`` -- 1-D NumPy array with one integer id per row
          of ``label_matrix``.
        * ``label_map`` -- 2-D NumPy array with one ``[class, id]`` row per
          unique class; ids run from 0 in order of first appearance. The
          unique classes and their count can be read from this map.

    Raises:
        IndexError: If a positional ``label_col`` is out of range.
        KeyError: If ``label_col`` names a column that does not exist.
    """
    # Imported locally so the function does not depend on module-level aliases.
    import numpy as np
    import pandas as pd

    # Anything int() accepts is treated as a column position and translated
    # to the column's name; everything else is taken to be a name already.
    try:
        position = int(label_col)
    except (TypeError, ValueError):
        pass
    else:
        label_col = label_matrix.columns[position]

    column = label_matrix[label_col]

    # Unique classes, in order of first appearance.
    unique_classes = pd.unique(column)

    # Map the n unique classes to the ids 0..n-1. The id column name is built
    # with format() so that non-string column labels (e.g. ints) also work.
    id_col = '{}_id'.format(label_col)
    label_map = pd.DataFrame(
        {label_col: unique_classes, id_col: range(len(unique_classes))},
        columns=[label_col, id_col],
    )

    # Look up the id of every row's class. Only the label column is touched,
    # so values in other columns that happen to equal a class are unaffected.
    mapped_labels = pd.Index(unique_classes).get_indexer(column)

    return np.asarray(mapped_labels), np.asarray(label_map)
# Web-page navigation residue captured with this snippet (not code):
# "Comment list", "Table of contents"