def scatter_(mat: T.Tensor, inds: T.LongTensor, val: T.Scalar) -> T.Tensor:
    """
    Write one value into a chosen column of each row of a matrix.

    For each position i of inds, the entry mat[i, inds[i]] is set to val.
    All other entries of mat are left unchanged.

    Note:
        The assignment happens in place on mat. The value returned is the
        result of mat.scatter_, which for torch tensors is mat itself.

    Args:
        mat: The tensor to modify (presumably 2-D, rows x columns; confirm
            against callers).
        inds: Integer (long) tensor giving the target column for each row.
            Presumably 1-D with at most one entry per row of mat.
        val: The scalar to write at each selected position.

    Returns:
        The modified tensor.
    """
    # Tensor.scatter_ needs an index with the same number of dimensions as
    # mat. Adding a trailing axis turns the per-row indices into a column
    # of single-element rows, so each row gets exactly one column index.
    column_index = inds.unsqueeze(1)
    return mat.scatter_(1, column_index, val)
评论列表
文章目录