def cross_entropy_loss(self, x, y):
    '''Per-sample cross entropy loss (no averaging across samples).

    Computes ``logsumexp(x[i]) - x[i, y[i]]`` for every row ``i``.

    Args:
      x: (tensor) class scores (logits), sized [N,D].
      y: (tensor) target class indices, sized [N,]; must be an integer
        (int64) tensor because it is used as a ``gather`` index.

    Return:
      (tensor) cross entropy loss, sized [N,1] (one column per sample).
    '''
    # Shift every row by its OWN maximum before exponentiating. Using a
    # single global maximum lets rows whose logits are all far below it
    # underflow to exp(...) == 0, which turns the loss into -inf.
    # The shift is detached: it is a numerical-stability constant and
    # must not contribute to the gradient.
    row_max = x.detach().max(dim=1, keepdim=True)[0]  # [N,1]
    sum_exp = torch.sum(torch.exp(x - row_max), dim=1, keepdim=True)  # [N,1]
    log_sum_exp = torch.log(sum_exp) + row_max  # [N,1]

    # Score assigned to the target class of each sample.
    target_score = x.gather(1, y.view(-1, 1))  # [N,1]
    return log_sum_exp - target_score
multibox_loss.py 文件源码
python
阅读 29
收藏 0
点赞 0
评论 0
评论列表
文章目录