def mubs_trace_components(left_matmul_closure, right_matmul_closure, size, num_samples,
                          tensor_cls=torch.Tensor, use_vars=False, dim_num=None):
    """Apply two matmul closures to a set of random complex-exponential probe vectors.

    Each of the ``num_samples`` probes (per batch entry) is the complex vector
    ``v[k] = exp(2*pi*i * (k*r1 + (k+1)*(k+2)/2 * r2) / size) / sqrt(size)`` for
    ``k = 0..size-1``, with ``r1`` and ``r2`` drawn uniformly from the integers
    ``0..size-1``. Its real and imaginary parts are stored as two real columns,
    and every column is scaled by ``sqrt(size / num_samples)``. The name suggests
    this feeds a mutually-unbiased-bases (MUB) stochastic trace estimate, e.g.
    summing ``left_res * right_res`` -- NOTE(review): confirm against callers.

    Args:
        left_matmul_closure: callable applied to the probe tensor; its result is
            returned first.
        right_matmul_closure: callable applied to the same probe tensor; its
            result is returned second.
        size: length of each probe vector (number of rows).
        num_samples: number of complex probes per batch entry.
        tensor_cls: tensor constructor used to allocate the probes.
        use_vars: if True, wrap the probe tensor in ``Variable`` before applying
            the closures.
        dim_num: optional batch size. If given, an independent set of
            ``num_samples`` probes is drawn for each of the ``dim_num`` entries.

    Returns:
        ``(left_matmul_closure(comps), right_matmul_closure(comps))`` where
        ``comps`` has shape ``(size, 2 * num_samples)`` if ``dim_num`` is None and
        ``(dim_num, size, 2 * num_samples)`` otherwise. Along the last dimension,
        columns ``[:num_samples]`` hold the real parts and ``[num_samples:]`` the
        imaginary parts of the same probes, so (up to rounding) every entry
        satisfies ``real**2 + imag**2 == 1 / num_samples``.
    """
    # Row indices k = 0..size-1 as a (size, 1) column of the requested tensor type.
    r1_coeff = tensor_cls(size)
    torch.arange(0, size, out=r1_coeff)
    r1_coeff.unsqueeze_(1)
    # Quadratic phase coefficient (k + 1)(k + 2) / 2 for each row k.
    r2_coeff = (r1_coeff + 1) * (r1_coeff + 2) / 2

    # One random integer pair (r1, r2) in [0, size) per probe, as (1, total_samples) rows.
    total_samples = num_samples if dim_num is None else num_samples * dim_num
    r1 = tensor_cls(total_samples).uniform_().mul_(size).floor().type_as(r1_coeff).unsqueeze(1).t()
    r2 = tensor_cls(total_samples).uniform_().mul_(size).floor().type_as(r1_coeff).unsqueeze(1).t()

    # (size, total_samples) phase matrix, computed once and shared by cos and sin.
    phases = ((2 * math.pi) / size) * (r1_coeff.matmul(r1) + r2_coeff.matmul(r2))
    real_comps = torch.cos(phases) / math.sqrt(size)
    imag_comps = torch.sin(phases) / math.sqrt(size)
    # float() guards against integer division under Python 2.
    coeff = math.sqrt(float(size) / num_samples)

    if dim_num is None:
        comps = torch.cat([real_comps, imag_comps], 1).mul_(coeff)
    else:
        # Split the sample axis per batch entry *before* concatenating, so each
        # batch entry receives both the real and the imaginary parts of its own
        # probes. Concatenating first and reshaping afterwards would hand batch
        # entries real parts without their matching imaginary parts.
        real_comps = real_comps.view(size, dim_num, num_samples)
        imag_comps = imag_comps.view(size, dim_num, num_samples)
        comps = torch.cat([real_comps, imag_comps], 2).mul_(coeff)
        comps = comps.transpose(0, 1).contiguous()

    if use_vars:
        comps = Variable(comps)
    left_res = left_matmul_closure(comps)
    right_res = right_matmul_closure(comps)
    return left_res, right_res
评论列表
文章目录