def broadcast(vec: T.FloatTensor, matrix: T.FloatTensor) -> T.FloatTensor:
    """
    Broadcast vec into the shape of matrix following numpy rules.

    vec ~ (N, 1) broadcasts to matrix ~ (N, M)
    vec ~ (1, N) and (N,) broadcast to matrix ~ (M, N)

    If both vec and matrix are 1D, vec is returned unchanged; no size
    check is performed in that case.

    Args:
        vec: A vector (either flat, row, or column).
        matrix: A matrix (i.e., a 2D tensor).

    Returns:
        tensor: A tensor of the same size as matrix containing the elements
            of the vector. When expansion happens, the result comes from
            ``Tensor.expand`` and is therefore a view of vec, not a copy.

    Raises:
        BroadcastError: if vec cannot be expanded to the shape of matrix
            (including when matrix has fewer dimensions than required).
    """
    try:
        if ndim(vec) == 1:
            if ndim(matrix) == 1:
                return vec
            # A flat (N,) vector is treated as a row (1, N) before expanding.
            return vec.unsqueeze(0).expand(matrix.size(0), matrix.size(1))
        return vec.expand(matrix.size(0), matrix.size(1))
    except (ValueError, RuntimeError, IndexError) as err:
        # torch's expand signals a size mismatch with RuntimeError, and
        # matrix.size(1) raises IndexError when matrix is not at least 2D.
        # ValueError is still caught for backward compatibility.
        raise BroadcastError(
            'cannot broadcast vector of dimension {} '
            'onto matrix of dimension {}'.format(shape(vec), shape(matrix))
        ) from err