aggregate.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:PyTorch-Encoding 作者: zhanghang1989 项目源码 文件源码
def assign(R, S):
    r"""
    Calculate assignment weights for given residuals (:math:`R`) and scale (:math:`S`)

    .. math::
        a_{ik} = \frac{exp(-s_k\|r_{ik}\|^2)}{\sum_{j=1}^K exp(-s_j\|r_{ik}\|^2)}

    Shape:
        - Input: :math:`R\in\mathcal{R}^{B\times N\times K\times D}` :math:`S\in \mathcal{R}^K` (where :math:`B` is batch, :math:`N` is total number of features, :math:`K` is number is codewords, :math:`D` is feature dimensions.)
        - Output :math:`A\in\mathcal{R}^{B\times N\times K}`

    """
    L = square_squeeze()(R)
    K = S.size(0)
    SL = L * S.view(1,1,K)
    return F.softmax(SL)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号