perplexity.py source code

python
Project: contiguous-succotash  Author: kefirski
import torch as t
import torch.nn.functional as F


def forward(self, logits, target):
    """
    :param logits: tensor with shape of [batch_size, seq_len, input_size] holding unnormalized scores
    :param target: LongTensor with shape of [batch_size, seq_len] holding the indexes to gather from logits
    :return: tensor with shape of [batch_size] holding the perplexity of each sequence
    """

    # Log-probabilities over the last (input_size) dimension; dim is passed
    # explicitly because calling F.log_softmax without it is deprecated.
    log_probs = F.log_softmax(logits, dim=2)

    # Log-probability of each target index, negated: per-token negative log-likelihood.
    target = target.unsqueeze(2)
    out = t.gather(log_probs, dim=2, index=target).squeeze(2).neg()

    # Perplexity of each sequence: exp of its mean per-token negative log-likelihood.
    ppl = out.mean(1).exp()

    return ppl
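As a quick sanity check, here is a minimal sketch that reimplements the same computation as a standalone function (the name `sequence_perplexity` is hypothetical, not from the project) and compares it with `F.cross_entropy(..., reduction="none")`, which yields the same per-token negative log-likelihood. It also checks a known case: with all-equal logits every token has probability 1/V, so each sequence's perplexity should come out equal to the vocabulary size V.

```python
import torch
import torch.nn.functional as F


def sequence_perplexity(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical standalone version of forward(): per-sequence perplexity.

    logits: [batch_size, seq_len, vocab_size] unnormalized scores
    target: [batch_size, seq_len] LongTensor of token indexes
    returns: [batch_size] tensor, exp(mean token NLL) per sequence
    """
    log_probs = F.log_softmax(logits, dim=2)
    nll = log_probs.gather(2, target.unsqueeze(2)).squeeze(2).neg()
    return nll.mean(1).exp()


torch.manual_seed(0)
batch_size, seq_len, vocab_size = 3, 5, 7
logits = torch.randn(batch_size, seq_len, vocab_size)
target = torch.randint(0, vocab_size, (batch_size, seq_len))

ppl = sequence_perplexity(logits, target)

# Cross-check: F.cross_entropy expects the class dimension second, i.e. [N, C, d1],
# and with reduction="none" returns the per-token NLL with shape [N, d1].
nll_ref = F.cross_entropy(logits.transpose(1, 2), target, reduction="none")
ppl_ref = nll_ref.mean(1).exp()
print(torch.allclose(ppl, ppl_ref))

# Uniform logits: every token has probability 1/vocab_size,
# so each sequence's perplexity should be (numerically close to) vocab_size.
uniform = torch.zeros(batch_size, seq_len, vocab_size)
print(sequence_perplexity(uniform, target))
```

Like the original, this averages over every position, so padding tokens in a padded batch count toward the perplexity unless they are masked out first.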