def backward(ctx, grad_outputs):
    """Propagate gradients back through a segment-wise reduction.

    This appears to be the ``backward`` of a custom autograd ``Function``
    that reduces groups ("segments") of input rows, with equal-length
    segments batched together (the per-length work is delegated to
    ``_MyMax.backward``).

    Steps:
      1. Reorder the rows of ``grad_outputs`` by the argsort of
         ``ctx.rev_segm_sorted``.
      2. Skip the first ``ctx.num_zeros`` reordered rows.
      3. For every length ``l`` with ``ctx.num_lengths[l] > 0``, take the next
         ``num_lengths[l] // l`` rows.
         - For ``l > 1`` they are expanded to ``num_lengths[l]`` rows through
           ``_MyMax.backward(ctx.maxes[l], ...)``.
         - For ``l == 1`` they pass through unchanged.
      4. Concatenate the per-length pieces and reorder the result by the
         argsort of ``ctx.lengths_sorted``.

    Args:
        ctx: autograd context. Reads ``rev_segm_sorted``, ``num_zeros``,
            ``num_lengths``, ``maxes`` and ``lengths_sorted``.
        grad_outputs: gradient w.r.t. the forward output. It must have at
            least two dims (``size(1)`` is used as the row width); presumably
            shaped (num_segments, features) — TODO confirm against forward.

    Returns:
        ``(grads, None, None, None)``. Only the first forward input receives
        a gradient; the other three get ``None``.
    """
    width = grad_outputs.size(1)

    # Argsort of the stored ordering (its inverse, if it is a permutation)
    # puts the output-gradient rows back into segment-processing order.
    _, restore_segments = torch.sort(ctx.rev_segm_sorted)
    grad_rows = torch.index_select(grad_outputs, 0, restore_segments)

    # The leading ``ctx.num_zeros`` rows are never read. Presumably they
    # belong to segments that own no input rows (empty segments).
    cursor = ctx.num_zeros
    pieces = []
    for seg_len, n_rows in enumerate(ctx.num_lengths):
        if n_rows > 0:
            # One gradient row per segment; n_rows input rows split evenly
            # into segments of length seg_len.
            n_segments = n_rows // seg_len
            piece = grad_rows.narrow(0, cursor, n_segments)
            if seg_len > 1:
                # Expand per-segment gradients back to per-input-row gradients
                # using the saved context for this segment length.
                expanded = _MyMax.backward(ctx.maxes[seg_len], piece)[0]
                piece = expanded.view(n_rows, width)
            cursor += n_segments
            pieces.append(piece)

    grads = torch.cat(pieces, 0)

    # Undo the length-based reordering of the input rows.
    _, restore_inputs = torch.sort(ctx.lengths_sorted)
    grads = torch.index_select(grads, 0, restore_inputs)
    return grads, None, None, None
评论列表
文章目录