trace.py source code

python

Project: gpytorch, author: jrg365
```python
import math

import torch
from torch.autograd import Variable


def mubs_trace_components(left_matmul_closure, right_matmul_closure, size, num_samples,
                          tensor_cls=torch.Tensor, use_vars=False, dim_num=None):
    # Per-row coefficients, shape (size, 1): r1_coeff[k] = k, r2_coeff[k] = (k + 1)(k + 2) / 2
    r1_coeff = tensor_cls(size)
    torch.arange(0, size, out=r1_coeff)
    r1_coeff.unsqueeze_(1)
    r2_coeff = ((r1_coeff + 1) * (r1_coeff + 2) / 2)

    # Random integer indices in {0, ..., size - 1}, stored as a row vector of shape (1, num_draws);
    # with dim_num set, draw num_samples indices for each of the dim_num probe sets
    num_draws = num_samples * dim_num if dim_num is not None else num_samples
    r1 = tensor_cls(num_draws).uniform_().mul_(size).floor().type_as(r1_coeff).unsqueeze(1).t()
    r2 = tensor_cls(num_draws).uniform_().mul_(size).floor().type_as(r1_coeff).unsqueeze(1).t()

    # Phases 2*pi/size * (k * r1 + (k + 1)(k + 2)/2 * r2); the outer products give shape (size, num_draws)
    two_pi_n = (2 * math.pi) / size
    phases = two_pi_n * (r1_coeff.matmul(r1) + r2_coeff.matmul(r2))
    real_comps = torch.cos(phases) / math.sqrt(size)
    imag_comps = torch.sin(phases) / math.sqrt(size)

    # Real and imaginary parts become separate columns, rescaled by sqrt(size / num_samples)
    coeff = math.sqrt(size / num_samples)
    comps = torch.cat([real_comps, imag_comps], 1).mul_(coeff)
    if use_vars:
        # Legacy autograd wrapper; since PyTorch 0.4 it simply returns a Tensor
        comps = Variable(comps)
    if dim_num is not None:
        # Reshape into a batch of dim_num matrices, each of shape (size, 2 * num_samples)
        comps = comps.t().contiguous().view(dim_num, 2 * num_samples, size).transpose(1, 2).contiguous()
    # Apply the caller-supplied matrix products to the probe vectors
    left_res = left_matmul_closure(comps)
    right_res = right_matmul_closure(comps)
    return left_res, right_res
```
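
To make the non-batched path (`dim_num=None`) concrete without a PyTorch dependency, here is a minimal pure-Python sketch of the same probe construction followed by a trace estimate. The helper names `mubs_probes` and `estimate_trace` are hypothetical, and the trace step is an assumption: the gpytorch code that consumes `left_res` and `right_res` is not shown on this page, so the sketch simply uses the probe columns in a Hutchinson-style sum of z^T A z.

```python
import math
import random


def mubs_probes(size, num_samples, rng=None):
    """Hypothetical helper: a size x (2 * num_samples) probe matrix as a list of rows.

    Mirrors the dim_num=None path of mubs_trace_components: column s holds the real
    part of a random MUB vector and column num_samples + s its imaginary part.
    """
    rng = rng if rng is not None else random.Random()
    # Random indices r1, r2 drawn uniformly from {0, ..., size - 1}
    r1 = [rng.randrange(size) for _ in range(num_samples)]
    r2 = [rng.randrange(size) for _ in range(num_samples)]
    two_pi_n = 2 * math.pi / size
    # 1/sqrt(size) normalisation times sqrt(size / num_samples) = 1/sqrt(num_samples)
    scale = 1.0 / math.sqrt(num_samples)
    rows = []
    for k in range(size):
        c2 = (k + 1) * (k + 2) / 2  # same quadratic coefficient as r2_coeff
        phases = [two_pi_n * (k * a + c2 * b) for a, b in zip(r1, r2)]
        real = [math.cos(p) * scale for p in phases]
        imag = [math.sin(p) * scale for p in phases]
        rows.append(real + imag)
    return rows


def estimate_trace(matrix, probes):
    """Hypothetical helper: sum of z^T A z over the probe columns z."""
    n = len(matrix)
    total = 0.0
    for j in range(len(probes[0])):
        z = [probes[i][j] for i in range(n)]
        az = [sum(matrix[i][k] * z[k] for k in range(n)) for i in range(n)]
        total += sum(zi * azi for zi, azi in zip(z, az))
    return total


if __name__ == "__main__":
    A = [[4.0, 1.0, 0.5],
         [1.0, 3.0, 0.2],
         [0.5, 0.2, 2.0]]
    Z = mubs_probes(3, 2000, random.Random(0))
    print("estimate:", estimate_trace(A, Z), "exact:", 9.0)
```

Each row of the probe matrix has squared norm exactly 1 (cos^2 + sin^2, summed over the samples and divided by num_samples), so the estimate is exact for diagonal matrices. Off-diagonal contributions are only zero in expectation, so for a dense matrix the error shrinks as num_samples grows.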
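
Why the phase formula yields an unbiased trace estimator can be sketched as follows. This derivation is an editorial sketch, not taken from the gpytorch source; it covers only the `dim_num=None` path and assumes the probes are applied to a real symmetric matrix A. Write N = `size`, m = `num_samples`, and c_k = (k+1)(k+2)/2 (the entries of `r2_coeff`). The columns of `comps` are then sqrt(N/m) a_s and sqrt(N/m) b_s, where v_s = a_s + i b_s is one random vector, so summing z^T A z over the columns gives (N/m) times the sum over s of (a_s^T A a_s + b_s^T A b_s).

```latex
% r_1, r_2 independent and uniform on {0, ..., N-1}
v_k = \frac{1}{\sqrt{N}} \exp\!\Big(\frac{2\pi i}{N}\big(k\,r_1 + c_k\,r_2\big)\Big),
\qquad v = a + i\,b

\mathbb{E}\big[v_j \overline{v_k}\big]
  = \frac{1}{N}\,\mathbb{E}\Big[e^{\frac{2\pi i}{N}(j-k)\,r_1}\Big]\,
    \mathbb{E}\Big[e^{\frac{2\pi i}{N}(c_j - c_k)\,r_2}\Big]
  = \frac{\delta_{jk}}{N},
\quad\text{since } \sum_{r=0}^{N-1} e^{2\pi i\,d\,r/N} = 0 \text{ for } d \not\equiv 0 \pmod N

v^{H} A v = a^{T} A a + b^{T} A b \qquad (A = A^{T} \text{ real})

\mathbb{E}\Big[\frac{N}{m}\sum_{s=1}^{m}\big(a_s^{T} A a_s + b_s^{T} A b_s\big)\Big]
  = \frac{N}{m}\sum_{s=1}^{m}\operatorname{tr}\big(A\,\mathbb{E}[v_s v_s^{H}]\big)
  = \operatorname{tr}(A)
```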