retina_net.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:qtim_ROP 作者: QTIM-Lab 项目源码 文件源码
def train(self, training_data, trees=100,rf_out=None):

        # Use CNN to extract features
        self.cnn.set_intermediate(self.feature_layer)
        features = self.extract_features(training_data)

        # Create random forest
        self.rf = RandomForestClassifier(n_estimators=trees, class_weight='balanced_subsample')
        X_train = features['y_pred']  # inputs to train the random forest
        y_train = np.asarray(features['y_true'])  # ground truth for random forest

        print "Training RF..."
        self.rf.fit(X_train, y_train)

        if rf_out:
            joblib.dump(self.rf, rf_out)

        return self.rf, X_train, y_train
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号