def forward(ctx, pred, labels, is_positive, ohem_ratio, group_size):
    """Compute a hinge loss with online hard example mining (OHEM) per group.

    For every sample ``i`` the hinge loss
    ``max(0, 1 - is_positive * pred[i, labels[i] - 1])`` is computed. The
    per-sample losses are reshaped into rows of ``group_size`` consecutive
    samples. In each row only the ``int(group_size * ohem_ratio)`` largest
    losses are kept, and all kept losses are summed.

    Args:
        ctx: autograd context. The values needed by ``backward`` are stored
            on it (``loss_ind``, ``labels``, ``slopes``, ``shape``,
            ``group_size``, ``num_group``).
        pred: score tensor indexed as ``pred[sample, column]``. Presumably
            per-class scores; TODO confirm against callers.
        labels: per-sample labels, either a sequence of ints or an integer
            tensor. ``labels[i] - 1`` selects the column of ``pred``, so
            labels are 1-based (a label of 0 wraps to the last column).
        is_positive: sign applied to the selected score. Presumably ``1``
            for positive samples and ``-1`` for negatives; TODO confirm.
        ohem_ratio: fraction of each group's losses to keep.
        group_size: number of consecutive samples per group. It must divide
            the number of samples.

    Returns:
        A 1-element tensor on ``pred``'s device holding the summed kept
        losses.

    Raises:
        AssertionError: if ``pred`` and ``labels`` have different lengths.
        RuntimeError: if the sample count is not a multiple of
            ``group_size`` (raised by ``view``).
    """
    n_sample = pred.size(0)
    assert n_sample == len(labels), "mismatch between sample size and label size"

    # Gather pred[i, labels[i] - 1] for all rows in one indexing op. A
    # per-element Python loop forces one device sync per sample on GPU.
    rows = torch.arange(n_sample, device=pred.device)
    cols = torch.as_tensor(labels, device=pred.device).long() - 1
    hinge = 1 - is_positive * pred[rows, cols]
    # where(h > 0, h, 0) mirrors Python's max(0, h) exactly (NaN -> 0),
    # unlike clamp, which would propagate NaN.
    hinge = torch.where(hinge > 0, hinge, torch.zeros_like(hinge))

    # Keep the per-sample losses/slopes on the CPU in the default dtype, as
    # before, so the state handed to backward is unchanged. copy_ performs
    # the device and dtype conversion.
    losses = torch.zeros(n_sample)
    losses.copy_(hinge)
    # Derivative of the hinge w.r.t. the selected score: -is_positive where
    # the hinge is active, 0 otherwise.
    slopes = torch.zeros(n_sample)
    slopes[losses != 0] = -is_positive

    losses = losses.view(-1, group_size).contiguous()
    sorted_losses, indices = torch.sort(losses, dim=1, descending=True)
    keep_num = int(group_size * ohem_ratio)
    # Sum the kept (hardest) losses of every group. Place the result on
    # pred's device instead of hard-coding .cuda(), which crashed without
    # CUDA.
    loss = sorted_losses[:, :keep_num].sum().reshape(1).to(pred.device)

    ctx.loss_ind = indices[:, :keep_num]  # per-group positions of kept losses
    ctx.labels = labels
    ctx.slopes = slopes
    ctx.shape = pred.size()
    ctx.group_size = group_size
    ctx.num_group = losses.size(0)
    return loss
# Scraped web-page residue (not code): "评论列表" = "Comment list",
# "文章目录" = "Table of contents".