CRFv2.py source code

python

Project: BiLSTM-CCM  Author: codedecde
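import torch

def _viterbi_decode(self, feats):
    # feats: (seq_len, tagset_size) emission scores for one sentence.
    # self.transitions[i, j] scores a move from previous tag j to current tag i.
    backpointers = []
    # forward_var[i]: best score of any tag path that ends in tag i so far
    forward_var = torch.zeros(self.tagset_size, 1).type(self.dtype)
    for ix, feat in enumerate(feats):
        if ix == 0:
            # First position: emission score plus start-of-sequence score
            forward_var = forward_var + feat.view(self.tagset_size, 1) + self.initial_weights
        else:
            # scores[i, j] = transitions[i, j] + forward_var[j]
            # (same matrix as repeat + transpose, built by broadcasting)
            scores = self.transitions + forward_var.view(1, -1)
            # Best previous tag for each current tag; keepdim=True keeps the
            # result at (tagset_size, 1) so the addition below does not
            # broadcast into a (tagset_size, tagset_size) matrix
            viterbi_vars, viterbi_idx = torch.max(scores, 1, keepdim=True)
            forward_var = feat.view(self.tagset_size, 1) + viterbi_vars
            backpointers.append(viterbi_idx)
    # Add end-of-sequence scores and pick the best final tag
    terminal_var = forward_var + self.final_weights
    _, best_tag_id = torch.max(terminal_var, 0)
    best_tag_id = best_tag_id.item()
    path_score = terminal_var[best_tag_id]
    # Follow the backpointers from the last position back to the first
    best_path = [best_tag_id]
    for bptrs_t in reversed(backpointers):
        best_tag_id = bptrs_t[best_tag_id].item()
        best_path.append(best_tag_id)
    best_path.reverse()
    return path_score, best_path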
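The same Viterbi recurrence can be sketched without PyTorch, which makes it easy to check against brute force. This is a minimal sketch, not the project's code: the function names `viterbi_decode` and `path_score` are hypothetical, and it assumes the same conventions as `_viterbi_decode` above, namely that `transitions[i][j]` scores a move from tag `j` to tag `i` and that separate `start` and `end` vectors play the roles of `initial_weights` and `final_weights`.

```python
from itertools import product
from typing import List, Sequence, Tuple


def viterbi_decode(
    emissions: Sequence[Sequence[float]],    # emissions[t][i]: score of tag i at position t
    transitions: Sequence[Sequence[float]],  # transitions[i][j]: score of moving from tag j to tag i
    start: Sequence[float],                  # start[i]: score of beginning in tag i (hypothetical name)
    end: Sequence[float],                    # end[i]: score of finishing in tag i (hypothetical name)
) -> Tuple[float, List[int]]:
    """Return (best score, best tag path) for a linear-chain CRF."""
    n_tags = len(start)
    # score[i]: best score of any partial path that ends in tag i
    score = [start[i] + emissions[0][i] for i in range(n_tags)]
    backpointers: List[List[int]] = []
    for emit in emissions[1:]:
        new_score: List[float] = []
        bptrs: List[int] = []
        for i in range(n_tags):
            # Previous tag j that maximizes transitions[i][j] + score[j]
            best_j = max(range(n_tags), key=lambda j: transitions[i][j] + score[j])
            bptrs.append(best_j)
            new_score.append(transitions[i][best_j] + score[best_j] + emit[i])
        score = new_score
        backpointers.append(bptrs)
    # Add the end scores and pick the best final tag
    final = [score[i] + end[i] for i in range(n_tags)]
    best_last = max(range(n_tags), key=lambda i: final[i])
    # Walk the backpointers from the last position back to the first
    path = [best_last]
    for step_bptrs in reversed(backpointers):
        path.append(step_bptrs[path[-1]])
    path.reverse()
    return final[best_last], path


def path_score(path, emissions, transitions, start, end) -> float:
    """Score one complete tag sequence under the same model."""
    total = start[path[0]] + emissions[0][path[0]]
    for t in range(1, len(path)):
        total += transitions[path[t]][path[t - 1]] + emissions[t][path[t]]
    return total + end[path[-1]]


# Toy example (made-up scores): 3 positions, 3 tags, checked by enumeration
emissions = [[1.0, 0.0, 0.5], [0.2, 1.5, 0.1], [0.3, 0.4, 2.0]]
transitions = [[0.5, -1.0, 0.0], [0.0, 0.2, -0.5], [-0.3, 0.1, 0.4]]
start = [0.1, -0.2, 0.0]
end = [0.0, 0.3, -0.1]

best, path = viterbi_decode(emissions, transitions, start, end)
brute = max(product(range(3), repeat=3),
            key=lambda p: path_score(p, emissions, transitions, start, end))
print(best, path, list(brute))
```

Enumeration costs O(tags^length), while Viterbi costs O(length × tags²), which is why the CRF decoder keeps one best score and one backpointer per tag at each position instead of scoring every sequence.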