def train_svms(train_file_dir='svm_train/',
               model_path='models/fine_tune.model',
               output_path='models/train_svm.model'):
    """Fit one linear SVM per training file on features from the fine-tuned net.

    Builds the network with ``create_alexnet()``, wraps it in ``tflearn.DNN``
    and loads weights from ``model_path``. Then, for every regular file in
    ``train_file_dir`` whose name does not contain "pkl", it:

    1. builds ``(X, Y)`` with ``generate_single_svm_train_data``;
    2. extracts one feature vector per sample with ``model.predict``;
    3. fits an ``svm.LinearSVC`` on those features.

    The list of fitted classifiers is written to ``output_path`` with
    ``joblib.dump``.

    Args:
        train_file_dir: Directory containing the per-SVM training files.
        model_path: Checkpoint prefix of the fine-tuned model. The function
            checks that ``model_path + '.index'`` exists before loading.
        output_path: File the list of fitted SVMs is dumped to.

    Returns:
        The list of fitted classifiers, ordered by sorted training-file
        name. Returns ``None`` (after printing a message) if the model
        checkpoint index file is missing.
    """
    if not os.path.isfile(model_path + '.index'):
        print("{} doesn't exist.".format(model_path))
        return None

    net = create_alexnet()
    model = tflearn.DNN(net)
    model.load(model_path)

    svms = []
    # os.listdir order is arbitrary; sort it so each SVM's position in the
    # dumped list is reproducible across runs and machines.
    for train_file in sorted(os.listdir(train_file_dir)):
        # Names containing "pkl" are skipped (presumably pickled artifacts,
        # not training data -- confirm against whatever writes this folder).
        if "pkl" in train_file:
            continue
        train_path = os.path.join(train_file_dir, train_file)
        # Sub-directories and other non-regular entries are not training files.
        if not os.path.isfile(train_path):
            continue
        X, Y = generate_single_svm_train_data(train_path)
        # predict() is called on a batch of one sample; [0] takes that
        # sample's output row.
        train_features = [model.predict([sample])[0] for sample in X]
        print("feature dimension of fitting: {}".format(np.shape(train_features)))
        clf = svm.LinearSVC()
        clf.fit(train_features, Y)
        svms.append(clf)

    joblib.dump(svms, output_path)
    return svms
# 评论列表 (web-page residue: "comment list")
# 文章目录 (web-page residue: "table of contents")