def aePredict(self, graph):
    """Parse ``graph`` greedily, one best-scoring valid transition at a time.

    The method calls ``self.initCG()``, swaps ``graph`` for the result of
    ``graph.cleaned()``, and then drives a single configuration (index 0 of
    an ``AEBeamConfiguration``) until it reports completion or offers no
    valid transition. At each step the valid transition with the highest
    score is applied; on a tie the lowest transition index wins, because
    ``max`` keeps the first maximal candidate it sees.

    Args:
        graph: Object exposing ``cleaned()``; the cleaned result must expose
            ``nodes`` and ``heads``.

    Returns:
        The result of ``graph.cleaned()`` with ``heads`` replaced by the
        heads read back from the configuration. Any head that is not
        strictly positive is stored as 0.
    """
    # NOTE(review): presumably resets the network's computation graph -- confirm.
    self.initCG()
    graph = graph.cleaned()
    lstm_carriers = self.getLSTMFeatures(graph.nodes)

    # Every call below addresses this one configuration index.
    # NOTE(review): the literal 1 passed to AEBeamConfiguration looks like the
    # beam width, which would make index 0 the only slot -- confirm.
    slot = 0
    node_count = len(graph.nodes)
    head_array = np.array(graph.heads)
    conf = AEBeamConfiguration(
        node_count,
        1,
        head_array,
        self.stack_features,
        self.buffer_features,
    )
    conf.initconf(slot, self.root_first)

    while not conf.isComplete(slot):
        allowed = conf.validTransitions(slot)
        if np.count_nonzero(allowed) == 0:
            # Nothing can be applied; keep whatever heads exist so far.
            break

        # The second value returned by _aeEvaluate is not needed here.
        scores, _ = self._aeEvaluate(conf.extractFeatures(slot), lstm_carriers)
        candidates = [
            (index, score)
            for index, score in enumerate(scores)
            if allowed[index]
        ]
        chosen, _ = max(candidates, key=itemgetter(1))
        conf.makeTransition(slot, chosen)

    # Non-positive head values are normalised to 0.
    graph.heads = [head if head > 0 else 0 for head in conf.getHeads(slot)]
    return graph
评论列表
文章目录