model.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:torch_light 作者: ne7ermore 项目源码 文件源码
def _score_sentence(self, input, tags):
        bsz, sent_len, l_size = input.size()
        score = Variable(self.torch.FloatTensor(bsz).fill_(0.))
        s_score = Variable(self.torch.LongTensor([[START]]*bsz))

        tags = torch.cat([s_score, tags], dim=-1)
        input_t = input.transpose(0, 1)

        for i, words in enumerate(input_t):
            temp = self.transitions.index_select(1, tags[:, i])
            bsz_t = gather_index(temp.transpose(0, 1), tags[:, i + 1])
            w_step_score = gather_index(words, tags[:, i+1])
            score = score + bsz_t + w_step_score

        temp = self.transitions.index_select(1, tags[:, -1])
        bsz_t = gather_index(temp.transpose(0, 1),
                    Variable(self.torch.LongTensor([STOP]*bsz)))
        return score+bsz_t
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号