def get_one_hot(in_matrix):
    """
    Reformat truth matrix to same size as the output of the dense network.

    Args:
        in_matrix: the categorized 1D matrix. Either a pandas categorical
            (a Series with ``category`` dtype, a ``Categorical`` or a
            ``CategoricalIndex``) or a numpy ndarray of labels.

    Returns:
        A float32 numpy array produced by scikit-learn's ``LabelBinarizer``.
        It has one column per class. For pandas categoricals, every declared
        category counts as a class, even if it does not occur in
        ``in_matrix``. Per the ``LabelBinarizer`` documentation, exactly two
        classes are encoded as a single column.

    Raises:
        ValueError: if ``in_matrix`` is neither categorical nor an ndarray.
    """
    # Objects without a ``dtype`` (lists, DataFrames, ...) fall through to
    # the ValueError below instead of failing with an AttributeError.
    dtype_name = getattr(getattr(in_matrix, "dtype", None), "name", None)
    if dtype_name == "category":
        # A Series exposes codes/categories through the ``.cat`` accessor;
        # Categorical and CategoricalIndex expose them directly.
        categorical = in_matrix.cat if hasattr(in_matrix, "cat") else in_matrix
        labels = np.asarray(categorical.codes)
        # Fit on every declared category code, not only the ones present.
        # This keeps the output width tied to the category set when some
        # categories are absent from this sample. Observed codes are unioned
        # in so that -1 (missing value) stays its own class, as before.
        classes = np.union1d(np.arange(len(categorical.categories)), labels)
    elif isinstance(in_matrix, np.ndarray):
        labels = in_matrix
        classes = in_matrix
    else:
        raise ValueError("Input matrix cannot be converted.")
    binarizer = LabelBinarizer()
    binarizer.fit(classes)
    return np.array(binarizer.transform(labels), dtype='float32')
评论列表
文章目录