def __init__(self, genres, data, type='knn', name='', clf_kwargs=None):
    """Set up the wrapper state and build the prototype estimator.

    Args:
        genres: Ordered sequence of genre labels. Each label is mapped to
            its position (via ``enumerate``) in ``self.m_genres``.
        data: Passed unchanged to ``self._convert_data``; its expected
            structure is defined by that method (not visible here).
        type: Key selecting the estimator class stored in
            ``self.proto_clf``: ``'knn'``, ``'svm'``, ``'dtree'``,
            ``'gnb'``, ``'perc'``, ``'mlp'`` or ``'ada'``. The name shadows
            the builtin ``type`` but is kept so keyword callers still work.
        name: Stored as ``self.display_name``.
        clf_kwargs: Optional mapping of keyword arguments forwarded to the
            estimator constructor. It is copied, so the caller's mapping is
            never modified.

    Raises:
        LookupError: If ``type`` is not one of the supported keys.
    """
    self.logger = get_logger('classifier')
    self.display_name = name
    self.genres = genres
    self.m_genres = {genre: i for i, genre in enumerate(genres)}
    # Unseeded here; per-instance so estimators below can share it.
    self.randstate = np.random.RandomState()
    self.scaler = StandardScaler()

    # Work on a private copy: ``random_state`` is injected below and must
    # not leak back into a dict the caller may reuse elsewhere.
    clf_kwargs = dict(clf_kwargs) if clf_kwargs else {}

    estimators = {
        'knn': KNeighborsClassifier,
        'svm': SVC,
        'dtree': DecisionTreeClassifier,
        'gnb': GaussianNB,
        'perc': Perceptron,
        'mlp': MLPClassifier,
        'ada': AdaBoostClassifier,
    }
    if type not in estimators:
        raise LookupError('Classifier type "{}" is invalid'.format(type))

    # Only these two estimators receive the instance RandomState; any
    # caller-supplied ``random_state`` is overridden for them.
    if type in ('svm', 'mlp'):
        clf_kwargs['random_state'] = self.randstate
    self.proto_clf = estimators[type](**clf_kwargs)

    self._convert_data(data)
    # ``type`` is the parameter here, so use ``__class__`` for the name.
    self.logger.info('Classifier: {} (params={})'.format(
        self.proto_clf.__class__.__name__,
        clf_kwargs
    ))
评论列表
文章目录