def scalar_broadcast_match(a, b):
    """Return both arguments as arrays, expanding a scalar to the other's shape.

    Both inputs are first converted with ``np.atleast_1d``. If exactly one of
    them holds a single element while the other does not, the single-element
    one is broadcast to the shape of the other. In every other case (both
    hold one element, or neither does) the converted arrays are returned
    unchanged; no check is made that their shapes agree.

    Parameters
    ----------
    a, b : array_like or scalar
        Values to convert and, where applicable, broadcast.

    Returns
    -------
    tuple of np.ndarray
        ``(a, b)`` as arrays with at least one dimension.

    Notes
    -----
    A broadcast result comes from ``np.broadcast_to`` and is therefore a
    read-only view; copy it before writing to it.
    """
    a, b = np.atleast_1d(a, b)
    if a.size == 1 and b.size != 1:
        # Collapse to 0-d first so that a single element carrying extra
        # dimensions (e.g. shape (1, 1)) can still be expanded to a
        # lower-rank target shape instead of raising ValueError.
        a = np.broadcast_to(a.reshape(()), b.shape)
    elif b.size == 1 and a.size != 1:
        b = np.broadcast_to(b.reshape(()), a.shape)
    return a, b
# Comment list
# Table of contents