def train(self, training_data, trees=100, rf_out=None, random_state=None):
    """Train a random forest on features extracted by the CNN.

    Calls ``self.cnn.set_intermediate(self.feature_layer)`` (presumably so
    the CNN outputs that layer's activations -- confirm against the CNN
    class). It then runs ``training_data`` through ``self.extract_features``
    and fits a ``RandomForestClassifier`` on the result. The fitted forest
    is stored on ``self.rf``.

    Args:
        training_data: Passed unchanged to ``self.extract_features``; its
            expected format is defined by that method (not visible here).
        trees: Number of trees (``n_estimators``) in the forest.
        rf_out: Optional output path. When truthy, the fitted forest is
            saved there with ``joblib.dump``.
        random_state: Optional seed (or ``RandomState``) forwarded to
            ``RandomForestClassifier`` so training can be made reproducible.
            ``None`` keeps sklearn's default behaviour.

    Returns:
        Tuple ``(rf, X_train, y_train)``: the fitted classifier, the
        ``'y_pred'`` entry of the extracted features (as returned by
        ``extract_features``), and the ``'y_true'`` labels as a NumPy array.

    Note:
        Whatever state ``set_intermediate`` puts on the CNN is not reverted
        by this method.
    """
    # Select the CNN layer whose outputs become the random-forest inputs.
    self.cnn.set_intermediate(self.feature_layer)
    features = self.extract_features(training_data)

    # 'balanced_subsample' recomputes class weights on each tree's bootstrap
    # sample, countering class imbalance in the labels.
    self.rf = RandomForestClassifier(
        n_estimators=trees,
        class_weight='balanced_subsample',
        random_state=random_state,
    )
    X_train = features['y_pred']  # CNN outputs used as RF inputs
    y_train = np.asarray(features['y_true'])  # ground-truth labels

    # Function-call form: identical output on Python 2, and unlike the old
    # print statement it is also valid Python 3 syntax.
    print("Training RF...")
    self.rf.fit(X_train, y_train)

    if rf_out:
        joblib.dump(self.rf, rf_out)
    return self.rf, X_train, y_train
评论列表
文章目录