baseModel.py 文件源码

python
阅读 41 收藏 0 点赞 0 评论 0

项目:multiNLI_encoder 作者: easonnie 项目源码 文件源码
def forward(self, s1, l1, s2, l2):
        p_s1 = self.Embd(s1)
        p_s2 = self.Embd(s2)

        s1_a_out = torch_util.auto_rnn_bilstm(self.lstm, p_s1, l1)
        s2_a_out = torch_util.auto_rnn_bilstm(self.lstm, p_s2, l2)

        s1_max_out = torch_util.max_along_time(s1_a_out, l1)
        s2_max_out = torch_util.max_along_time(s2_a_out, l2)

        features = torch.cat([s1_max_out, s2_max_out, torch.abs(s1_max_out - s2_max_out), s1_max_out * s2_max_out], dim=1)

        out = self.classifier(features)
        return out
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号