def symbolic_kernel(self, X):
    """Compute the kernel matrix between ``X`` and ``self.X_kernel``.

    The kernel is chosen by ``self.kernel_type``:

    * ``'linear'``: ``alpha * X @ X_kernel.T + c``
    * ``'poly'``:   ``(alpha * X @ X_kernel.T + c) ** degree``
    * ``'rbf'``:    ``exp(-D**2 / sigma_kernel**2)``, where ``D`` is the
      result of ``sym_distance_matrix(X, X_kernel, self_similarity=False)``
      (presumably pairwise distances -- confirm against that helper).

    Args:
        X: Input tensor. Assumed to be (n, d), with ``self.X_kernel`` of
            shape (m, d), giving an (n, m) kernel -- TODO confirm with callers.

    Returns:
        The kernel tensor ``K``.

    Raises:
        ValueError: If ``self.kernel_type`` is not 'linear', 'poly' or 'rbf'.
            (Subclass of ``Exception``, so existing handlers still catch it.)
    """
    kernel_type = self.kernel_type
    if kernel_type == 'linear':
        # torch.dot only accepts 1-D tensors; X_kernel is transposed as a
        # 2-D tensor, so a matrix product is required here.
        K = self.alpha * torch.matmul(X, self.X_kernel.transpose(0, 1)) + self.c
    elif kernel_type == 'poly':
        K = (self.alpha * torch.matmul(X, self.X_kernel.transpose(0, 1)) + self.c) ** self.degree
    elif kernel_type == 'rbf':
        D = sym_distance_matrix(X, self.X_kernel, self_similarity=False)
        # Note: -D ** 2 parses as -(D ** 2).
        K = torch.exp(-D ** 2 / (self.sigma_kernel ** 2))
    else:
        # Format the message instead of passing two args, which rendered
        # as a tuple in the original.
        raise ValueError(f'Unknown kernel type: {kernel_type}')
    return K
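# NOTE(review): the two lines below are web-page residue from the source this code was copied from.
# 评论列表 ("comment list")
# 文章目录 ("table of contents")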