model.py source code


Project: torch_light    Author: ne7ermore
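import torch
from torch.autograd import Variable

# START (index of the start tag) and log_sum_exp are assumed to be defined
# elsewhere in model.py; they are not part of this excerpt.

def forward(self, input):
    # input: (bsz, sent_len, label_size) emission scores, one per tag per word.
    bsz, sent_len, l_size = input.size()
    # Log-space "zero" (-10000) for every tag except START, which gets log(1) = 0.
    # self.torch is presumably torch or torch.cuda, chosen when the model is built.
    init_alphas = self.torch.FloatTensor(bsz, self.label_size).fill_(-10000.)
    init_alphas[:, START].fill_(0.)
    # Variable is a no-op wrapper since PyTorch 0.4; kept for older versions.
    forward_var = Variable(init_alphas)

    # Step through the sentence: input_t is (sent_len, bsz, label_size).
    input_t = input.transpose(0, 1)
    for words in input_t:
        alphas_t = []
        for next_tag in range(self.label_size):
            # Emission score of next_tag, broadcast across all previous tags.
            emit_score = words[:, next_tag].contiguous()
            emit_score = emit_score.unsqueeze(1).expand_as(words)

            # Scores of moving from every previous tag into next_tag.
            trans_score = self.transitions[next_tag, :].view(1, -1).expand_as(words)
            next_tag_var = forward_var + trans_score + emit_score
            # Sum out the previous tag in log space; the second argument
            # presumably keeps the reduced dimension, giving (bsz, 1).
            alphas_t.append(log_sum_exp(next_tag_var, True))
        forward_var = torch.cat(alphas_t, dim=-1)

    # Log partition function per sentence. Note that no transition into a
    # STOP tag is added before this final reduction.
    return log_sum_exp(forward_var)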
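The method above is the CRF forward algorithm: it computes the log of the sum of exp(score) over every possible tag sequence (the log partition function) in O(sent_len × label_size²) time instead of enumerating all label_size^sent_len paths. To make the recursion easy to check without PyTorch, here is a minimal pure-Python sketch of the same computation for a single sentence (batch size 1). It is an illustration, not code from torch_light: the names `crf_log_partition`, `brute_force_log_partition` and `NEG_INF_APPROX` are hypothetical, and the transition orientation (row = next tag, column = previous tag) is inferred from how `self.transitions[next_tag, :]` is added to `forward_var` above.

```python
import itertools
import math
from typing import List, Sequence

# Hypothetical name for the "log zero" stand-in the original method uses.
NEG_INF_APPROX = -10000.0


def log_sum_exp(xs: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(x))) over a flat list of floats."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def crf_log_partition(emissions: List[List[float]],
                      transitions: List[List[float]],
                      start: int) -> float:
    """Forward algorithm for one sentence (hypothetical helper).

    emissions[t][j]   : score of tag j at position t (sent_len x n_tags)
    transitions[j][i] : score of moving from tag i to tag j (n_tags x n_tags)
    """
    n_tags = len(transitions)
    # alpha[i]: log-sum of the scores of all prefixes ending in tag i.
    alpha = [NEG_INF_APPROX] * n_tags
    alpha[start] = 0.0
    for word in emissions:
        alpha = [
            log_sum_exp([alpha[prev] + transitions[nxt][prev] + word[nxt]
                         for prev in range(n_tags)])
            for nxt in range(n_tags)
        ]
    # As in the method above, no STOP transition is added at the end.
    return log_sum_exp(alpha)


def brute_force_log_partition(emissions: List[List[float]],
                              transitions: List[List[float]],
                              start: int) -> float:
    """Enumerate every tag sequence explicitly; only feasible for tiny inputs."""
    n_tags = len(transitions)
    scores = []
    for path in itertools.product(range(n_tags), repeat=len(emissions)):
        prev, s = start, 0.0
        for t, tag in enumerate(path):
            s += transitions[tag][prev] + emissions[t][tag]
            prev = tag
        scores.append(s)
    return log_sum_exp(scores)


if __name__ == "__main__":
    emissions = [[0.5, 1.0, -0.3], [0.2, -1.0, 0.8], [1.5, 0.1, 0.0]]
    transitions = [[0.0, 0.3, -0.2], [0.1, 0.0, 0.4], [-0.5, 0.2, 0.0]]
    fwd = crf_log_partition(emissions, transitions, start=0)
    brute = brute_force_log_partition(emissions, transitions, start=0)
    print(abs(fwd - brute) < 1e-9)  # True
```

The brute-force check agrees with the forward recursion because exp(-10000) underflows to zero, so the -10000 initialization makes every path effectively start from the START tag. The batched PyTorch method does the same work, but vectorizes over the batch and over the previous tag, leaving only the loops over time steps and over `next_tag`.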