lstm_parser_bi_fast.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:depccg 作者: masashi-y 项目源码 文件源码
def __call__(self, xs):
        """Compute the joint tagging + parsing loss for one minibatch.

        ``xs`` is a sequence of per-sample tuples.  Each tuple is either
        ``(w, s, p, l, cat_t, dep_t)`` or the same six fields followed by
        a per-sample loss ``weight``.  When the first sample has exactly
        six fields, every sample gets a scalar weight of 1 allocated
        with the array module of the first ``w``.
        (NOTE(review): the meaning of w/s/p/l is defined by
        ``self.forward`` -- confirm against that method.)

        ``w, s, p, l`` are handed to ``self.forward``; ``dep_t`` is
        passed along only while ``self.train`` is truthy, otherwise
        ``None`` is passed in its place.

        The tagging/parsing losses are the weighted sums of the
        per-sample softmax cross entropies; the accuracies are the
        per-sample accuracies (ignoring label ``IGNORE``) averaged over
        the batch.  All four values are reported through
        ``chainer.report`` and the sum of the two losses is returned.
        """
        n_samples = len(xs)

        # Transpose the batch: one tuple per field instead of per sample.
        has_weights = len(xs[0]) != 6
        columns = tuple(zip(*xs))
        if has_weights:
            ws, ss, ps, ls, cat_ts, dep_ts, weights = columns
        else:
            ws, ss, ps, ls, cat_ts, dep_ts = columns
            xp = chainer.cuda.get_array_module(ws[0])
            weights = [xp.array(1, 'f') for _ in range(n_samples)]

        forward_dep_ts = dep_ts if self.train else None
        cat_ys, dep_ys = self.forward(ws, ss, ps, ls, forward_dep_ts)

        def _add(lhs, rhs):
            return lhs + rhs

        def _weighted_loss(ys, ts):
            # Sum over samples of weight * cross entropy.
            terms = [weight * F.softmax_cross_entropy(y, t)
                     for y, t, weight in zip(ys, ts, weights)]
            return reduce(_add, terms)

        def _mean_accuracy(ys, ts):
            # Per-sample accuracies averaged over the batch size.
            terms = [F.accuracy(y, t, ignore_label=IGNORE)
                     for y, t in zip(ys, ts)]
            return reduce(_add, terms) / n_samples

        cat_loss = _weighted_loss(cat_ys, cat_ts)
        cat_acc = _mean_accuracy(cat_ys, cat_ts)

        dep_loss = _weighted_loss(dep_ys, dep_ts)
        dep_acc = _mean_accuracy(dep_ys, dep_ts)

        observation = {
            "tagging_loss": cat_loss,
            "tagging_accuracy": cat_acc,
            "parsing_loss": dep_loss,
            "parsing_accuracy": dep_acc,
        }
        chainer.report(observation, self)
        return cat_loss + dep_loss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号