entropy.py source code

python

Project: allennlp · Author: allenai
# From allennlp.training.metrics.entropy: this is the __call__ method of the
# Entropy metric class. Imports the snippet relies on:
from typing import Optional

import torch
from torch.autograd import Variable


def __call__(self,  # type: ignore
             logits: torch.Tensor,
             mask: Optional[torch.Tensor] = None):
    """
    Parameters
    ----------
    logits : ``torch.Tensor``, required.
        A tensor of unnormalized log probabilities of shape (batch_size, ..., num_classes).
    mask : ``torch.Tensor``, optional (default = None).
        A masking tensor of shape (batch_size, ...).
    """
    # Detach the inputs from the autograd graph; metric bookkeeping
    # should not contribute gradients.
    logits, mask = self.unwrap_to_tensors(logits, mask)

    if mask is None:
        mask = torch.ones(logits.size()[:-1])

    # Normalize the logits into log probabilities over the class dimension.
    log_probs = torch.nn.functional.log_softmax(Variable(logits), dim=-1).data
    # Multiplying by the mask zeroes out the padded positions.
    probabilities = torch.exp(log_probs) * mask.unsqueeze(-1)
    # Per-position entropy: H = -sum_c p_c * log(p_c).
    weighted_negative_likelihood = - log_probs * probabilities
    entropy = weighted_negative_likelihood.sum(-1)

    # Accumulate the mask-averaged entropy for this batch; get_metric()
    # later averages the running total over the number of batches seen.
    self._entropy += entropy.sum() / mask.sum()
    self._count += 1
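
For context, this ``__call__`` belongs to AllenNLP's ``Entropy`` metric, which accumulates the average per-position entropy of a model's output distribution across batches. Below is a minimal usage sketch; the ``allennlp.training.metrics`` import path and the ``get_metric()`` return value match AllenNLP releases of this era, but treat both as assumptions if your version differs.

# Minimal usage sketch (assumes allennlp.training.metrics.Entropy exposes
# the __call__ above plus a get_metric() returning the running average).
import torch
from allennlp.training.metrics import Entropy

metric = Entropy()

# Fake batch: 2 sequences of length 3 over 5 classes.
logits = torch.randn(2, 3, 5)
# Mask out the last position of the second sequence.
mask = torch.ones(2, 3)
mask[1, 2] = 0.0

metric(logits, mask)                   # accumulate one batch
average_entropy = metric.get_metric()  # running average over batches seen
print(average_entropy)

A uniform distribution over 5 classes would report entropy close to log(5) ≈ 1.609, so values well below that indicate the model is making confident predictions.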