def rcumsum(input: "torch.Tensor", dim: int = 0) -> "torch.Tensor":
    """Compute a reverse (right-to-left) cumulative sum along ``dim``.

    Element ``i`` of the result is the sum of ``input[i:]`` taken along
    ``dim``; e.g. ``rcumsum(tensor([1, 2, 3, 4]))`` gives
    ``tensor([10, 9, 7, 4])``.

    Args:
        input: tensor to accumulate over. (The name shadows the builtin
            ``input`` but is kept for backward compatibility with callers
            passing it by keyword.)
        dim: dimension along which the reverse cumulative sum is taken;
            negative values count from the last dimension.

    Returns:
        Tensor with the same shape as ``input``, on the same device; its
        dtype follows ``torch.cumsum``'s promotion rules.
    """
    # Flip -> forward cumsum -> flip back. torch.flip stays on input's
    # device, whereas a torch.LongTensor index is always a CPU tensor and
    # makes index_select fail for CUDA inputs.
    flipped = torch.flip(input, dims=(dim,))
    return torch.flip(flipped.cumsum(dim), dims=(dim,))
评论列表
文章目录