def train(train_dataTables, human_marks):
    """Fit the module-level ``classifier`` on per-candidate (prior, SR) features.

    Every candidate entity of every mention in the training tables becomes one
    sample ``[entity.popular, mention.getSR(index, entities)]``. The sample is
    labelled 1 when the candidate's ``id`` equals the human-marked id for that
    cell, and 0 otherwise. Cells whose mention has ``cid == -1`` are skipped.
    An ``sklearn.svm.SVC(probability=True)`` is then fitted on all samples and
    bound to the global ``classifier``.

    Args:
        train_dataTables: iterable of tables. Each table provides ``row`` and
            ``col`` counts, cell access via ``table[i][j]`` (a mention object
            exposing ``cid``, ``candidates`` and ``getSR``), and
            ``get_entities(i, j)``.
        human_marks: gold annotations aligned with ``train_dataTables``,
            indexed as ``human_marks[table_index][i][j]['id']``.

    Raises:
        ValueError: if the tables yield no training samples at all. The
            global ``classifier`` is left untouched in that case. It is also
            left untouched whenever ``fit`` itself raises.
    """
    global classifier
    samples = []
    target = []
    for nn, dataTable in enumerate(train_dataTables):
        # range (not xrange) so this runs on both Python 2 and Python 3.
        for i in range(dataTable.row):
            for j in range(dataTable.col):
                mention = dataTable[i][j]
                if mention.cid == -1:
                    continue
                entities = dataTable.get_entities(i, j)
                true_id = human_marks[nn][i][j]['id']
                for ii, entity in enumerate(mention.candidates):
                    prior = entity.popular
                    SR = mention.getSR(ii, entities)
                    samples.append([prior, SR])
                    # Binary label: 1 only for the human-marked candidate.
                    target.append(int(true_id == entity.id))
    if not samples:
        # Fail with a clear message instead of letting sklearn choke on an
        # empty array; checked before importing sklearn so nothing heavy is
        # loaded when there is nothing to fit.
        raise ValueError("train(): no training samples; every cell was "
                         "skipped (cid == -1) or had no candidates")
    from sklearn import svm
    # Fit a local model first so a failing fit() cannot leave an unfitted
    # SVC in the global classifier.
    model = svm.SVC(probability=True)
    model.fit(samples, target)
    classifier = model
评论列表
文章目录