__init__.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:gpytorch 作者: jrg365 项目源码 文件源码
def left_interp(interp_indices, interp_values, rhs):
    is_vector = rhs.ndimension() == 1

    if is_vector:
        res = rhs.index_select(0, interp_indices.view(-1)).view(*interp_values.size())
        res = res.mul(interp_values)
        return res.sum(-1)

    else:
        interp_size = list(interp_indices.size()) + [rhs.size(-1)]
        rhs_size = deepcopy(interp_size)
        rhs_size[-3] = rhs.size()[-2]
        interp_indices_expanded = interp_indices.unsqueeze(-1).expand(*interp_size)
        res = rhs.unsqueeze(-2).expand(*rhs_size).gather(-3, interp_indices_expanded)
        res = res.mul(interp_values.unsqueeze(-1).expand(interp_size))
        return res.sum(-2)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号