ja_lstm_parser_ph.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:depccg 作者: masashi-y 项目源码 文件源码
def predict(self, xs):
        """
        batch: list of splitted sentences
        """
        batchsize = len(xs)
        fs = [self.extractor.process(x)[:2] for x in xs]
        ws, cs = concat_examples(fs, padding=IGNORE)
        cat_ys, dep_ys = self.forward(ws, cs)
        cat_ys = F.transpose(F.stack(cat_ys, 2), (0, 2, 1))
        # dep_ys = F.transpose(F.stack(dep_ys, 2), (0, 2, 1))

        cat_ys = [F.log_softmax(
            F.reshape(y, (y.shape[1], -1))[1:len(x) + 1]).data for x, y in \
                zip(xs, F.split_axis(cat_ys, batchsize, 0))]

        dep_ys = [F.log_softmax(y[1:len(x) + 1, :len(x) + 1]).data \
                for x, y in zip(xs, dep_ys)]
        assert len(cat_ys) == len(dep_ys)
        return zip(cat_ys, dep_ys)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号