pytorch_model.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:biaffineparser 作者: chantera 项目源码 文件源码
def extract_best_label_logits(self, arc_logits, label_logits, lengths):
        pred_arcs = torch.squeeze(
            torch.max(arc_logits, dim=1)[1], dim=1).data.cpu().numpy()
        size = label_logits.size()
        output_logits = _model_var(
            self.model,
            torch.zeros(size[0], size[1], size[3]))
        for batch_index, (_logits, _arcs, _length) \
                in enumerate(zip(label_logits, pred_arcs, lengths)):
            for i in range(_length):
                output_logits[batch_index] = _logits[_arcs[i]]
        return output_logits
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号