def diag(self):
    """
    Gets the diagonal of the Kronecker Product matrix wrapped by this object.

    Entry ``j`` of the result is the product, over the ``d`` dimensions ``i``, of

        sum_{a, b} C_lefts[i, j, a] * C_rights[i, j, b]
                   * columns[i, |J_lefts[i, j, a] - J_rights[i, j, b]| mod n_grid]

    i.e. the ``j``-th diagonal entry of each interpolated Toeplitz factor, where
    the Toeplitz matrix of dimension ``i`` is the symmetric one generated by
    ``columns[i]`` (only ``|p - q|`` is used to index it). ``added_diag`` is
    added to the result when it is not None.

    NOTE(review): assumes ``J_lefts``/``J_rights`` are integer index tensors of
    shape ``(d, n_data, n_interp)``, ``C_lefts``/``C_rights`` are weight tensors
    of the same shapes, and ``columns`` is a ``(d, n_grid)`` Variable -- confirm
    against the constructor.

    Returns:
        A Variable of length ``n_data`` holding the diagonal.

    Raises:
        RuntimeError: if the left and right interpolation sets describe a
            different number of rows (the matrix is not square).
    """
    if len(self.J_lefts[0]) != len(self.J_rights[0]):
        raise RuntimeError('diag not supported for non-square interpolated Toeplitz matrices.')
    d, n_data, n_interp_left = self.J_lefts.size()
    n_interp_right = self.J_rights.size(2)
    n_grid = len(self.columns[0])
    pair_size = (d, n_data, n_interp_left, n_interp_right)

    # Outer product of left and right interpolation weights for every
    # (dimension, data point) pair: shape (d, n_data, n_interp_left, n_interp_right).
    interps_values = torch.matmul(self.C_lefts.unsqueeze(3), self.C_rights.unsqueeze(2))

    # A symmetric Toeplitz entry T[p, q] depends only on |p - q|, so every pair
    # of grid indices maps to an offset into that dimension's first column.
    left_interps_indices = self.J_lefts.unsqueeze(3).expand(*pair_size)
    right_interps_indices = self.J_rights.unsqueeze(2).expand(*pair_size)
    toeplitz_indices = (left_interps_indices - right_interps_indices).fmod(n_grid).abs().long()

    # gather along dim 1: row i of the result reads only from columns[i], so a
    # single call looks up every dimension's Toeplitz values at once, on the
    # same device/dtype as `columns`.
    toeplitz_vals = self.columns.gather(1, Variable(toeplitz_indices.view(d, -1)))
    toeplitz_vals = toeplitz_vals.view(*pair_size)

    # Per-dimension diagonal entries (d, n_data), then product over dimensions.
    diag = (Variable(interps_values) * toeplitz_vals).sum(3).sum(2)
    diag = diag.prod(0)

    # Out-of-place add: prod()'s backward uses its output, so modifying it in
    # place would invalidate the autograd graph.
    if self.added_diag is not None:
        diag = diag + self.added_diag
    return diag
kronecker_product_lazy_variable.py 文件源码
python
阅读 47
收藏 0
点赞 0
评论 0
评论列表
文章目录