def classifier_train(feature_matrix_0, feature_matrix_1, algorithm='SVM',
                     per_feature=False):
    """Train a binary classifier on the examples of two classes.

    The examples of both classes are stacked row-wise, z-score normalized,
    and used to fit a scikit-learn ``svm.SVC`` with its default parameters.

    Arguments
    feature_matrix_0: matrix with the examples for class 0 (one example per
        row).
    feature_matrix_1: matrix with the examples for class 1 (one example per
        row; same number of columns as feature_matrix_0).
    algorithm: name of the algorithm to use. Currently only 'SVM' is
        supported (compared case-insensitively).
    per_feature: if False (default), normalize with a single mean and
        standard deviation computed over all values. If True, normalize
        each feature (column) with its own mean and standard deviation.

    Outputs
    classifier: trained classifier (scikit-learn object)
    mu_ft, std_ft: normalization parameters. These are scalars when
        per_feature is False and per-column arrays when it is True. New data
        should be normalized as (X - mu_ft) / std_ft before prediction.

    Raises
    ValueError: if `algorithm` is not 'SVM'.
    """
    if str(algorithm).upper() != 'SVM':
        raise ValueError(
            "Unsupported algorithm {!r}; only 'SVM' is supported".format(
                algorithm))

    # Class labels: 0 for the first matrix, 1 for the second. scikit-learn
    # expects a 1-D label vector of shape (n_samples,); a column vector
    # triggers a DataConversionWarning in fit().
    y = np.concatenate((np.zeros(feature_matrix_0.shape[0]),
                        np.ones(feature_matrix_1.shape[0])))

    # Stack the examples of both classes in the same order as the labels.
    features_all = np.concatenate((feature_matrix_0, feature_matrix_1),
                                  axis=0)

    # Normalize inputs. A zero standard deviation (constant data or constant
    # feature) would produce NaN/inf after division, so it is replaced by 1.
    if per_feature:
        mu_ft = np.mean(features_all, axis=0)
        std_ft = np.std(features_all, axis=0)
        std_ft[std_ft == 0] = 1.0
    else:
        mu_ft = np.mean(features_all)
        std_ft = np.std(features_all)
        if std_ft == 0:
            std_ft = 1.0
    X = (features_all - mu_ft) / std_ft

    # Train SVM using scikit-learn's default parameters.
    classifier = svm.SVC()
    classifier.fit(X, y)
    return classifier, mu_ft, std_ft
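# NOTE(review): two stray web-page navigation labels (Chinese for
# "comment list" and "table of contents") stood here. They are scraping
# residue rather than Python, so they are kept only as this comment.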