ja_lstm_parser_bi.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:depccg 作者: masashi-y 项目源码 文件源码
def forward(self, ws, cs, ls, dep_ts=None):
        """Compute dependency-arc and category scores for a batch.

        Each sentence is encoded by concatenating its word embeddings with
        character features (char embedding -> convolution -> max pooling),
        running the result through a forward and a backward LSTM, and
        joining both directions. Two biaffine scorers are then applied to
        the encoded sentence.

        Args:
            ws: sequence of per-sentence word inputs, each fed to
                ``self.emb_word`` (presumably word-id arrays -- confirm
                against the caller).
            cs: sequence of per-sentence character inputs, each fed to
                ``self.emb_char``.
            ls: sequence aligned with ``cs``; the first element of each
                entry is cast to ``int`` and used as the pooling window
                height (presumably the padded word length -- confirm).
            dep_ts: optional per-sentence head indices. When given they are
                used to pick head rows for the category scorer; otherwise
                the argmax over axis 1 of the arc scores is used.

        Returns:
            A ``(cat_ys, dep_ys)`` tuple of lists with one entry per
            sentence: outputs of ``self.biaffine_tag`` and
            ``self.biaffine_arc`` respectively.
        """
        n_sents = len(ws)
        xp = chainer.cuda.get_array_module(ws[0])
        training = self.train
        mlp_dropout = 0.32

        def project(layer, x):
            # Shared pattern for all four projections feeding the biaffines.
            return F.elu(F.dropout(layer(x), mlp_dropout, train=training))

        word_feats = [self.emb_word(w) for w in ws]

        char_feats = []
        for chars, lengths in zip(cs, ls):
            # Insert a channel axis at position 1 before the convolution.
            embedded = F.expand_dims(self.emb_char(chars), 1)
            pooled = F.max_pooling_2d(
                self.conv_char(embedded), (int(lengths[0]), 1))
            char_feats.append(F.squeeze(pooled))

        xs_f = []
        for w, c in zip(word_feats, char_feats):
            xs_f.append(
                F.dropout(F.concat([w, c]),
                          self.dropout_ratio, train=training))
        # The backward LSTM reads each sentence reversed along axis 0.
        xs_b = [x[::-1] for x in xs_f]

        cx_f, hx_f, cx_b, hx_b = self._init_state(xp, n_sents)
        _, _, fwd_states = self.lstm_f(hx_f, cx_f, xs_f, train=training)
        _, _, bwd_states = self.lstm_b(hx_b, cx_b, xs_b, train=training)

        hs = []
        for fwd, bwd in zip(fwd_states, bwd_states):
            # Undo the reversal so both directions line up position-wise.
            hs.append(F.concat([fwd, bwd[::-1]]))

        dep_ys = []
        for h in hs:
            dep_repr = project(self.arc_dep, h)
            head_repr = project(self.arc_head, h)
            dep_ys.append(self.biaffine_arc(dep_repr, head_repr))

        if dep_ts is None:
            heads = [F.argmax(y, axis=1) for y in dep_ys]
        else:
            heads = dep_ts

        cat_ys = []
        for h, t in zip(hs, heads):
            dep_repr = project(self.rel_dep, h)
            # Select, for every position, the encoder row of its head.
            head_rows = F.embed_id(t, h, ignore_label=IGNORE)
            head_repr = project(self.rel_head, head_rows)
            cat_ys.append(self.biaffine_tag(dep_repr, head_repr))

        return cat_ys, dep_ys
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号