def Softmax(x):
    """Softmax over the last axis of x.

    Works for inputs of rank 1 and above. Normalization is always along the
    final axis; all leading axes are treated as batch dimensions.

    Parameters:
        x: input tensor (framework tensor supporting .shape, Exp, Einsum).

    Returns:
        Tensor of the same shape as x whose last axis sums to 1.

    Raises:
        ValueError: if x is 0-dimensional (no axis to normalize over).
    """
    # TODO(review): numerically unstable for large inputs — shifting by the
    # per-row max before Exp would fix this, but that needs a max op in this
    # framework; confirm one exists before changing.
    exp = Exp(x)
    ndim = len(x.shape)
    if ndim == 0:
        raise ValueError("Softmax requires an input with at least 1 dimension")
    if ndim == 1:
        # For 1D the plain scalar sum divides fine; the broadcast workaround
        # below is only needed when batch axes must be kept.
        return exp / Einsum("i->", exp)
    # einsum can't broadcast a summed-out axis back
    # (https://github.com/numpy/numpy/issues/9984), so sum over the last axis
    # while introducing a size-1 axis "o" from the dummy operand np.array([1]);
    # the resulting (..., 1) sum then broadcasts in the division.
    # Letters skip 'i' and 'o', which are reserved for the summed/dummy axes.
    batch = "abcdefghjklmnpqrstuvwxyz"[: ndim - 1]
    return exp / Einsum(batch + "i,o->" + batch + "o", exp, np.array([1]))