cnn_rnn.py 文件源码

python
阅读 48 收藏 0 点赞 0 评论 0

项目:transfer 作者: kimiyoung 项目源码 文件源码
def predict(self, tx, tm, twx, tcm, tgaze, tlemma = None, tpos = None):
        i = 0
        pys = []
        while i < self.tx.shape[0]:
            # j = min(self.x.shape[0], i + self.test_batch_size)
            j = i + self.test_batch_size
            s_x, s_m, s_wx, s_cm = tx[i: j], tm[i: j], twx[i: j], tcm[i: j]
            s_gaze = tgaze[i: j] if self.use_gaze else None
            s_lemma = tlemma[i: j] if self.use_lemma else None
            s_pos = tpos[i: j] if self.use_pos else None
            pys.append(self.test_fn(s_x, s_m, s_wx, s_cm, s_gaze, s_lemma, s_pos))
            i = j
        py = np.vstack(tuple(pys))
        if self.use_crf:
            return py.flatten()
        else:
            return py.argmax(axis = 1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号