def _OrthonormalColumnBasis(A):
  """Return an orthonormal basis for the numerical column space of ``A``.

  Uses an economic, column-pivoted QR decomposition. With pivoting, the
  magnitudes of R's diagonal are non-increasing. Columns of Q beyond the
  numerical rank do not lie in A's column space: they are an arbitrary
  orthonormal completion, and keeping them would inject spurious directions
  into any subspace comparison. Those columns are therefore dropped, using
  the same style of tolerance as ``numpy.linalg.matrix_rank``.

  Args:
    A: 2-D array. It may be overwritten (``overwrite_a=True``).

  Returns:
    Array of shape ``(A.shape[0], rank)`` with orthonormal columns. It has
    zero columns when ``A`` is numerically zero.
  """
  from scipy.linalg import qr  # public API; scipy.linalg.decomp_qr is deprecated

  q, r, _ = qr(A, overwrite_a=True, mode='economic', pivoting=True)
  diag = np.abs(np.diag(r))
  if diag.size == 0 or diag[0] == 0:
    return q[:, :0]
  tol = diag[0] * max(A.shape) * np.finfo(r.dtype).eps
  rank = int(np.count_nonzero(diag > tol))
  return q[:, :rank]


def ComputeCCA(X, Y):
  """Return the mean canonical correlation between the columns of X and Y.

  Both inputs are first passed through ``NormCenterMatrix``, which is defined
  elsewhere in this module. An orthonormal basis of each input's numerical
  column space is then built. The singular values of ``X_q.T @ Y_q`` are the
  canonical correlations; they are clipped to [0, 1] and averaged.

  Args:
    X: 2-D array with one row per observation and one column per variable.
    Y: 2-D array with the same number of rows as ``X``.

  Returns:
    The mean of the canonical correlations. Returns NaN when either
    normalised input has numerical rank 0, because no canonical correlations
    exist in that case.

  Raises:
    AssertionError: if the row counts differ or there is only one row.
  """
  assert X.shape[0] == Y.shape[0], (X.shape, Y.shape, "Unequal number of rows")
  assert X.shape[0] > 1, (X.shape, "Must have more than 1 row")
  X = NormCenterMatrix(X)
  Y = NormCenterMatrix(Y)
  # Upper bound on the number of canonical correlations; computed before the
  # QR step, which may overwrite the normalised arrays in place.
  d = min(X.shape[1], Y.shape[1])

  X_q = _OrthonormalColumnBasis(X)
  Y_q = _OrthonormalColumnBasis(Y)
  if X_q.shape[1] == 0 or Y_q.shape[1] == 0:
    # Numerically-zero input: the correlations are undefined, and an SVD of an
    # empty matrix is not portable across numpy/scipy versions.
    return np.nan

  C = np.dot(X_q.T, Y_q)
  r = linalg.svd(C, full_matrices=False, compute_uv=False)
  r = r[:d]  # defensive: at most min(rank_x, rank_y) <= d values remain
  r = np.clip(r, 0.0, 1.0)  # remove round-off that pushes values outside [0, 1]
  return r.mean()
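# 评论列表 ("comment list") -- leftover text from the web page this code was copied from
# 文章目录 ("table of contents") -- leftover text from the web page this code was copied from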