multibox_loss.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:DSOD-Pytorch-Implementation 作者: Ellinier 项目源码 文件源码
def cross_entropy_loss(self, x, y):
        '''Cross entropy loss w/o averaging across all samples.

        Args:
          x: (tensor) sized [N,D].
          y: (tensor) sized [N,].

        Return:
          (tensor) cross entroy loss, sized [N,].
        '''
        # print(x.size()) # [8732, 16]
        xmax = x.data.max()
        # print(x.data.size()) # [8732, 16]
        # print(xmax.size()) # max--float object
        log_sum_exp = torch.log(torch.sum(torch.exp(x-xmax), 1)) + xmax
        # print(log_sum_exp.size()) # [8732,]
        # print(x.gather(1, y.view(-1,1)).size()) # [8732, 1]
        # print((log_sum_exp.view(-1, 1) - x.gather(1, y.view(-1,1))).size())
        return log_sum_exp.view(-1, 1) - x.gather(1, y.view(-1,1))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号