def predict(self, tx, tm, twx, tcm, tgaze, tlemma=None, tpos=None):
    """Run ``self.test_fn`` over the test inputs in mini-batches and return predictions.

    All inputs are sliced in lockstep along their first axis into chunks of
    ``self.test_batch_size`` rows (the last chunk may be shorter). The optional
    feature inputs are sliced only when the matching flag
    (``self.use_gaze`` / ``self.use_lemma`` / ``self.use_pos``) is set;
    otherwise ``None`` is passed to ``self.test_fn`` in their place.

    Args:
        tx, tm, twx, tcm: Test input arrays, each sliced with the same row
            ranges as ``tx`` (presumably row-aligned per example - TODO confirm).
        tgaze: Gaze features; only read when ``self.use_gaze`` is true.
        tlemma: Lemma features; only read when ``self.use_lemma`` is true.
        tpos: POS features; only read when ``self.use_pos`` is true.

    Returns:
        The per-batch outputs of ``self.test_fn`` stacked with ``np.vstack``.
        When ``self.use_crf`` is set they are flattened; otherwise
        ``argmax(axis=1)`` is taken. If ``tx`` has no rows, an empty integer
        array is returned.
    """
    n_rows = len(tx)
    if n_rows == 0:
        # np.vstack raises on an empty list, so report "no predictions" directly.
        return np.empty(0, dtype=int)

    batch_size = self.test_batch_size
    pys = []
    # Bound the loop by the tx *argument*, not self.tx: the data being
    # predicted need not have the same length as anything stored on the model.
    for i in range(0, n_rows, batch_size):
        j = i + batch_size  # slicing clamps j, so the final batch may be short
        s_gaze = tgaze[i:j] if self.use_gaze else None
        s_lemma = tlemma[i:j] if self.use_lemma else None
        s_pos = tpos[i:j] if self.use_pos else None
        pys.append(self.test_fn(tx[i:j], tm[i:j], twx[i:j], tcm[i:j],
                                s_gaze, s_lemma, s_pos))

    py = np.vstack(pys)
    if self.use_crf:
        return py.flatten()
    return py.argmax(axis=1)
评论列表
文章目录