kronecker_product_lazy_variable.py source code

python

Project: gpytorch   Author: jrg365
import torch
from torch.autograd import Variable


def diag(self):
    """
    Gets the diagonal of the Kronecker product matrix wrapped by this object.

    For each of the d dimensions, entry k of the diagonal is
    sum_{a,b} C_left[k, a] * C_right[k, b] * t[|J_left[k, a] - J_right[k, b]|],
    where t is that dimension's (symmetric) Toeplitz column; the per-dimension
    results are then multiplied together.
    """
    if len(self.J_lefts[0]) != len(self.J_rights[0]):
        raise RuntimeError('diag not supported for non-square interpolated Toeplitz matrices.')
    d, n_data, n_interp = self.J_lefts.size()
    n_grid = len(self.columns[0])

    # Outer product of left and right interpolation weights: d x n_data x n_interp x n_interp
    left_interps_values = self.C_lefts.unsqueeze(3)
    right_interps_values = self.C_rights.unsqueeze(2)
    interps_values = torch.matmul(left_interps_values, right_interps_values)

    left_interps_indices = self.J_lefts.unsqueeze(3).expand(d, n_data, n_interp, n_interp)
    right_interps_indices = self.J_rights.unsqueeze(2).expand(d, n_data, n_interp, n_interp)

    # Symmetric Toeplitz: entry (p, q) of factor i is columns[i, |p - q|]
    toeplitz_indices = (left_interps_indices - right_interps_indices).fmod(n_grid).abs().long()

    # Look up each factor's indices in its own Toeplitz column (row i of self.columns),
    # then restore the d x n_data x n_interp x n_interp shape of interps_values
    toeplitz_vals = self.columns.gather(1, Variable(toeplitz_indices.view(d, -1)))
    toeplitz_vals = toeplitz_vals.view(d, n_data, n_interp, n_interp)

    # Per-dimension diagonals (d x n_data), multiplied over the d dimensions
    diag = (Variable(interps_values) * toeplitz_vals).sum(3).sum(2)
    diag = diag.prod(0)

    if self.added_diag is not None:
        # Out-of-place add: prod's backward reuses its output, so it must not be modified in place
        diag = diag + self.added_diag

    return diag
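
To check what `diag` computes, here is a standalone sketch in NumPy (chosen instead of PyTorch only so it runs without the surrounding gpytorch class). The names `interp_toeplitz_kron_diag` and `dense_reference` are hypothetical. It assumes, as the method does, symmetric Toeplitz factors given by their first columns, and for simplicity uses the same indices and weights on the left and right. It also assumes each row of the full interpolation matrix W is the Kronecker product of its per-dimension sparse rows; under that assumption the mixed-product property gives w_k^T (T_1 ⊗ ... ⊗ T_d) w_k = ∏_i w_k^(i)T T_i w_k^(i), which is why the method multiplies over dimensions with `prod(0)`. The sketch omits `fmod(n_grid)`, which leaves the index unchanged whenever all indices lie in [0, n_grid).

```python
import numpy as np


def toeplitz_from_column(col):
    """Dense symmetric Toeplitz matrix with T[p, q] = col[|p - q|]."""
    n = len(col)
    idx = np.abs(np.arange(n)[:, None] - np.arange(n)[None, :])
    return col[idx]


def interp_toeplitz_kron_diag(columns, J, C):
    """Hypothetical helper: diag of W (T_1 kron ... kron T_d) W^T, no dense matrices.

    columns: (d, n_grid) first columns of the symmetric Toeplitz factors
    J: (d, n_data, n_interp) integer grid indices per dimension
    C: (d, n_data, n_interp) interpolation weights per dimension
    """
    d, n_data, n_interp = J.shape
    # Outer product of the weights of each data point: (d, n_data, n_interp, n_interp)
    weights = C[:, :, :, None] * C[:, :, None, :]
    # Toeplitz lookup index |J[a] - J[b]| for each pair of interpolation points
    toeplitz_idx = np.abs(J[:, :, :, None] - J[:, :, None, :])
    # vals[i, k, a, b] = columns[i, toeplitz_idx[i, k, a, b]] (same role as torch gather)
    vals = np.take_along_axis(columns, toeplitz_idx.reshape(d, -1), axis=1)
    vals = vals.reshape(d, n_data, n_interp, n_interp)
    # Per-dimension quadratic forms, then the product over dimensions
    return (weights * vals).sum(axis=(2, 3)).prod(axis=0)


def dense_reference(columns, J, C):
    """Hypothetical helper: build W and the Kronecker product densely, take the diagonal."""
    d, n_data, _ = J.shape
    n_grid = columns.shape[1]
    T = np.ones((1, 1))
    W = np.ones((n_data, 1))
    for i in range(d):
        T = np.kron(T, toeplitz_from_column(columns[i]))
        W_i = np.zeros((n_data, n_grid))
        for k in range(n_data):
            # add.at accumulates weights if an index repeats within a row
            np.add.at(W_i[k], J[i, k], C[i, k])
        # Row-wise Kronecker product, in the same ordering as np.kron
        W = np.einsum('ka,kb->kab', W, W_i).reshape(n_data, -1)
    return np.diag(W @ T @ W.T)


rng = np.random.default_rng(0)
d, n_grid, n_data, n_interp = 2, 6, 5, 3
columns = rng.random((d, n_grid))
J = rng.integers(0, n_grid, size=(d, n_data, n_interp))
C = rng.random((d, n_data, n_interp))
print(np.allclose(interp_toeplitz_kron_diag(columns, J, C), dense_reference(columns, J, C)))  # True
```

The gather-style lookup touches only n_data * n_interp^2 entries per dimension, while the dense reference builds an n_grid^d by n_grid^d matrix, so the dense version is useful only as a check on small inputs.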