ops.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:autodiff 作者: bgavran 项目源码 文件源码
def Softmax(x):
    # TODO make this numerically stable by shifting by max?
    exp = Exp(x)
    if len(x.shape) == 1:  # workaround because numpy einsum can't broadcast? https://github.com/numpy/numpy/issues/9984
        return exp / Einsum("i->", exp)
    elif len(x.shape) == 2:
        return exp / Einsum("bi,o->bo", exp, np.array([1]))
    elif len(x.shape) == 3:
        return exp / Einsum("abi,o->abo", exp, np.array([1]))
    elif len(x.shape) == 4:
        return exp / Einsum("abci,o->abco", exp, np.array([1]))
    else:
        raise ValueError("5D tensors not yet supported")
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号