train_svms.py source code

python

Project: rcnn-with-tflearn    Author: Redoblue
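import os

import joblib  # standalone joblib; sklearn.externals.joblib is removed in recent scikit-learn
import numpy as np
import tflearn
from sklearn import svm

# create_alexnet() and generate_single_svm_train_data() come from elsewhere
# in the rcnn-with-tflearn project (not shown here).


def train_svms():
    # tflearn saves a TensorFlow checkpoint as several files (.index, .meta,
    # .data-*), so the .index file is checked to see whether the model exists.
    if not os.path.isfile('models/fine_tune.model.index'):
        print("models/fine_tune.model doesn't exist.")
        return

    # Rebuild the AlexNet graph and restore the fine-tuned weights.
    net = create_alexnet()
    model = tflearn.DNN(net)
    model.load('models/fine_tune.model')

    train_file_dir = 'svm_train/'
    # os.listdir() returns files in arbitrary order; the position of each
    # SVM in `svms` follows that order.
    flist = os.listdir(train_file_dir)
    svms = []
    for train_file in flist:
        if "pkl" in train_file:  # skip pickle (.pkl) files in the directory
            continue
        # Each remaining file yields the training data for one binary SVM.
        X, Y = generate_single_svm_train_data(os.path.join(train_file_dir, train_file))
        # Run each sample through the network and use its output as the feature vector.
        train_features = []
        for i in X:
            feats = model.predict([i])
            train_features.append(feats[0])
        print("feature dimension of fitting: {}".format(np.shape(train_features)))
        clf = svm.LinearSVC()
        clf.fit(train_features, Y)
        svms.append(clf)
    # Save all trained SVMs together as one list.
    joblib.dump(svms, 'models/train_svm.model')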
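
The listing depends on the project's fine-tuned AlexNet and its svm_train/ files, so it cannot run on its own. The sketch below reproduces only the per-class SVM step under stated assumptions: random vectors stand in for the CNN features, and make_class_data, NUM_CLASSES, FEAT_DIM, and PER_CLASS are hypothetical names introduced for illustration. As in train_svms(), it fits one binary LinearSVC per class, keeps them in a list, and saves the list with joblib; it then reloads the list and scores one feature vector with every SVM.

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.svm import LinearSVC

# Hypothetical stand-in for the AlexNet features: random vectors, with the
# positive samples of class k shifted along dimension k so they are separable.
rng = np.random.default_rng(0)
NUM_CLASSES, FEAT_DIM, PER_CLASS = 3, 16, 40


def make_class_data(k):
    """Hypothetical helper: (features, labels) for the binary problem of class k."""
    pos = rng.normal(size=(PER_CLASS, FEAT_DIM)) + 4.0 * np.eye(FEAT_DIM)[k]
    neg = rng.normal(size=(PER_CLASS, FEAT_DIM))
    X = np.vstack([pos, neg])
    y = np.array([1] * PER_CLASS + [0] * PER_CLASS)
    return X, y


# One binary linear SVM per class, collected in a list.
svms = []
for k in range(NUM_CLASSES):
    X, y = make_class_data(k)
    clf = LinearSVC()
    clf.fit(X, y)
    svms.append(clf)

# Save and reload the whole list, as train_svms() does with joblib.dump().
path = os.path.join(tempfile.mkdtemp(), "train_svm.model")
joblib.dump(svms, path)
loaded = joblib.load(path)

# Score one feature vector (the class-1 positive mean) with every class's SVM.
query = 4.0 * np.eye(FEAT_DIM)[1]
scores = [float(c.decision_function(query.reshape(1, -1))[0]) for c in loaded]
print("per-class SVM scores:", scores)
```

Fitting one binary SVM per class follows the R-CNN scheme the project is named after, where each object class gets its own linear SVM over CNN features.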