kronecker_product_lazy_variable.py 文件源码

python
阅读 42 收藏 0 点赞 0 评论 0

项目:gpytorch 作者: jrg365 项目源码 文件源码
def _derivative_quadratic_form_factory(self, *args):
        """Build a closure that returns per-argument gradients of a quadratic form.

        ``args`` must have one of these layouts:

        * ``(columns,)``
        * ``(columns, W_left, W_right)``
        * ``(columns, W_left, W_right, added_diag)``

        ``W_left`` / ``W_right`` are applied with ``torch.dsmm``, so they are
        presumably sparse (interpolation) matrices -- TODO confirm with callers.

        Returns:
            ``closure(left_vectors, right_vectors)``, which returns a tuple with
            one entry per element of ``args``:

            * the result of ``kp_sym_toeplitz_derivative_quadratic_form`` in the
              ``columns`` slot;
            * ``None`` for ``W_left`` and ``W_right``;
            * when ``added_diag`` is given, a tensor of ``len(added_diag)`` zeros
              whose first entry is ``(left_vectors * right_vectors).sum()``.
              Only entry 0 is filled; presumably the added diagonal is a single
              shared value -- NOTE(review): confirm.

            The closure raises ``ValueError`` if ``len(args)`` is not 1, 3 or 4.
        """
        def closure(left_vectors, right_vectors):
            # Treat a single vector as a batch of one row so that the 1-D and
            # 2-D inputs go through the same code below.
            if left_vectors.ndimension() == 1:
                left_factor = left_vectors.unsqueeze(0)
                right_factor = right_vectors.unsqueeze(0)
            else:
                left_factor = left_vectors
                right_factor = right_vectors

            if len(args) == 1:
                columns, = args
                # One-element tuple: exactly one gradient per entry of ``args``.
                return kp_sym_toeplitz_derivative_quadratic_form(columns, left_factor, right_factor),

            if len(args) == 3:
                columns, W_left, W_right = args
                diag_grad = None
            elif len(args) == 4:
                columns, W_left, W_right, added_diag = args
                # Computed on the factors *before* the W projection below.
                diag_grad = columns.new(len(added_diag)).zero_()
                diag_grad[0] = (left_factor * right_factor).sum()
            else:
                # Previously fell through and returned None silently.
                raise ValueError(
                    "_derivative_quadratic_form_factory expects 1, 3 or 4 args, "
                    "got {}".format(len(args))
                )

            # (W^T @ F^T)^T == F @ W: right-multiply each factor by its W matrix.
            left_factor = torch.dsmm(W_left.t(), left_factor.t()).t()
            right_factor = torch.dsmm(W_right.t(), right_factor.t()).t()

            res = kp_sym_toeplitz_derivative_quadratic_form(columns, left_factor, right_factor)
            if diag_grad is None:
                return res, None, None
            return res, None, None, diag_grad

        return closure
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号