def train(self, depgraphs, modelfile):
    """
    Train an SVM classifier from dependency graphs and pickle it to disk.

    The training examples are first written to a temporary file by the
    algorithm-specific helper (arc-standard or arc-eager, chosen by
    ``self._algorithm``), then loaded back in svmlight format to fit the
    classifier.  The temporary file is always removed, whether or not
    training succeeds.

    :param depgraphs : list of DependencyGraph as the training data
    :type depgraphs : list(DependencyGraph)
    :param modelfile : file name to save the trained model
    :type modelfile : str
    """
    # Create the temporary file *before* entering the try block: if creation
    # itself fails there is nothing to clean up, and the ``finally`` clause
    # must not touch an unbound ``input_file`` (which would mask the real
    # error with a NameError).
    input_file = tempfile.NamedTemporaryFile(
        prefix='transition_parse.train',
        dir=tempfile.gettempdir(),
        delete=False)
    try:
        # The ``with`` block guarantees the handle is closed (and its
        # contents flushed) even if example generation raises, so the
        # file can be re-read by name and later removed on any platform.
        with input_file:
            if self._algorithm == self.ARC_STANDARD:
                self._create_training_examples_arc_std(depgraphs, input_file)
            else:
                self._create_training_examples_arc_eager(depgraphs, input_file)

        # Using the temporary file to train the libsvm classifier
        x_train, y_train = load_svmlight_file(input_file.name)

        # The parameter is set according to the paper:
        # Algorithms for Deterministic Incremental Dependency Parsing by Joakim Nivre
        # TODO: because of probability = True => very slow due to
        # cross-validation. Need to improve the speed here
        model = svm.SVC(
            kernel='poly',
            degree=2,
            coef0=0,
            gamma=0.2,
            C=0.5,
            verbose=True,
            probability=True)
        model.fit(x_train, y_train)

        # Save the model to file name (as pickle).  Use a context manager so
        # the output file is flushed and closed before returning.
        with open(modelfile, 'wb') as model_out:
            pickle.dump(model, model_out)
    finally:
        # delete=False above means removal is our responsibility.
        remove(input_file.name)
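# (web-page residue, not part of the code) Comment list
# (web-page residue, not part of the code) Table of contents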