ja_lstm_parser_ph.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:depccg 作者: masashi-y 项目源码 文件源码
def forward(self, ws, cs):
        """Run the tagger/parser network over a batch of sentences.

        Word embeddings and convolution-pooled character features are
        concatenated, passed through two-layer forward and backward LSTM
        stacks, and the resulting per-token states feed a category
        classifier and a biaffine dependency scorer.

        Args:
            ws: word inputs; after ``self.emb_word`` they are used as
                (batch, length, word_dim).
            cs: character inputs with a 3-D shape
                (batch, length, max_word_len).

        Returns:
            A pair ``(cat_ys, dep_ys)``: ``cat_ys`` holds one classifier
            output per token position (each covering the whole batch), and
            ``dep_ys`` holds one biaffine output per sentence in the batch.
        """
        n_batch, n_steps, word_len = cs.shape
        word_vecs = self.emb_word(ws)  # (batch, length, word_dim)

        # Character features: embed, convolve, then max-pool over the
        # character axis so every token ends up with one vector.
        # NOTE(review): the hard-coded 50 is presumably the character
        # embedding width -- confirm against the definition of emb_char.
        char_vecs = self.emb_char(cs)
        char_maps = F.reshape(
            char_vecs, (n_batch * n_steps, 1, word_len, 50))
        conv_out = self.conv_char(char_maps)
        pooled = F.max_pooling_2d(conv_out, (word_len, 1))
        char_feats = F.reshape(pooled, (n_batch, n_steps, self.char_dim))

        # Join word and character features, then move the token axis to the
        # front so the sequence can be cut into one slice per time step.
        feats = F.concat([word_vecs, char_feats], 2)
        feats = F.transpose(feats, (1, 0, 2))
        feats = F.dropout(feats, self.dropout_ratio, train=self.train)
        steps = F.split_axis(feats, n_steps, 0)

        fwd_states = []
        bwd_states = []
        self._init_state()
        flat_shape = (n_batch, -1)
        # The forward stack reads the sentence left-to-right while the
        # backward stack reads it right-to-left, one step each per iteration.
        for x_fwd, x_bwd in zip(steps, reversed(steps)):
            x_fwd = F.reshape(x_fwd, flat_shape)
            fwd_states.append(self.lstm_f2(self.lstm_f1(x_fwd)))
            x_bwd = F.reshape(x_bwd, flat_shape)
            bwd_states.append(self.lstm_b2(self.lstm_b1(x_bwd)))

        # Backward outputs were collected last-token-first; flip them so
        # both directions line up on the same token before concatenating.
        bi_states = []
        for fwd, bwd in zip(fwd_states, bwd_states[::-1]):
            bi_states.append(F.concat([fwd, bwd]))

        # Category scores: one output per token position.
        cat_ys = []
        for state in bi_states:
            hidden = F.elu(self.linear_cat1(state))
            hidden = F.dropout(hidden, 0.5, train=self.train)
            cat_ys.append(self.linear_cat2(hidden))

        # Regroup the per-position states by sentence: stack the positions
        # on a new last axis, swap it with the feature axis, and split
        # along the batch axis.
        stacked = F.stack(bi_states, 2)
        stacked = F.transpose(stacked, (0, 2, 1))

        # Dependency scores: one biaffine output per sentence.
        dep_ys = []
        for sent in F.split_axis(stacked, n_batch, 0):
            sent = F.reshape(sent, (n_steps, -1))
            dep_in = F.relu(
                F.dropout(self.linear_dep(sent), 0.32, train=self.train))
            head_in = F.relu(
                F.dropout(self.linear_head(sent), 0.32, train=self.train))
            dep_ys.append(self.biaffine(dep_in, head_in))

        return cat_ys, dep_ys
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号