ja_lstm_parser_ph.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:depccg 作者: masashi-y 项目源码 文件源码
def __call__(self, ws, cs, cat_ts, dep_ts):
        """Compute the combined tagging and parsing loss for one batch.

        Runs ``self.forward(ws, cs)``, scores its two outputs against the
        gold labels, reports the four metrics through ``chainer.report``
        and returns the sum of both losses.

        Args:
            ws: First input, passed to ``self.forward`` unchanged.
            cs: Second input, passed to ``self.forward`` unchanged.
            cat_ts: Gold category labels; its ``shape`` is unpacked as
                ``(batchsize, length)``.
            dep_ts: Gold dependency labels; split into ``batchsize``
                pieces along axis 0, each reshaped to ``(length,)``.

        Returns:
            ``cat_loss + dep_loss``.
        """
        batchsize, length = cat_ts.shape

        def fold_sum(terms):
            # Left-to-right pairwise addition; no initial value is supplied,
            # so ``terms`` must be non-empty.
            return reduce(lambda acc, term: acc + term, terms)

        cat_ys, dep_ys = self.forward(ws, cs)

        # ---- tagging ----
        # Drop the first and last prediction
        # (presumably sentence-boundary positions -- TODO confirm).
        cat_ys = cat_ys[1:-1]
        # Transpose, then cut into `length` slices along axis 0 and flatten
        # each slice to (batchsize,).
        cat_slices = F.split_axis(F.transpose(cat_ts), length, 0)
        cat_ts = [F.reshape(piece, (batchsize,)) for piece in cat_slices]
        assert len(cat_ys) == len(cat_ts)

        cat_pairs = list(zip(cat_ys, cat_ts))
        cat_loss = fold_sum(
            [F.softmax_cross_entropy(pred, gold) for pred, gold in cat_pairs])
        cat_acc = fold_sum(
            [F.accuracy(pred, gold, ignore_label=IGNORE)
             for pred, gold in cat_pairs])

        # ---- parsing ----
        # Original note: hs [(length, hidden_dim), ...]
        # Each element of dep_ys likewise loses its first and last row.
        dep_ys = [scores[1:-1] for scores in dep_ys]
        dep_slices = F.split_axis(dep_ts, batchsize, 0)
        dep_ts = [F.reshape(piece, (length,)) for piece in dep_slices]

        dep_pairs = list(zip(dep_ys, dep_ts))
        dep_loss = fold_sum(
            [F.softmax_cross_entropy(pred, gold) for pred, gold in dep_pairs])
        dep_acc = fold_sum(
            [F.accuracy(pred, gold, ignore_label=IGNORE)
             for pred, gold in dep_pairs])

        # Accuracies were summed per slice above; turn them into means.
        # Tagging is sliced `length` ways, parsing `batchsize` ways.
        cat_acc /= length
        dep_acc /= batchsize

        observations = {
            "tagging_loss": cat_loss,
            "tagging_accuracy": cat_acc,
            "parsing_loss": dep_loss,
            "parsing_accuracy": dep_acc,
        }
        chainer.report(observations, self)
        return cat_loss + dep_loss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号