def pca_fit(X, n_components):
    """Fit PCA on ``X`` via a thin SVD of the mean-centered data.

    Args:
        X: Data matrix. Assumed to be 2-D with samples along axis 0 and
            features along axis 1 (the mean is taken over ``axis=0``) --
            TODO confirm against callers.
        n_components: Number of leading principal axes to keep. Values
            larger than the number of available axes are clamped by slicing.

    Returns:
        A tuple ``(components, mean)``:
          * ``components`` -- tensor whose rows are the leading principal
            axes, ordered by decreasing singular value, shape
            ``(n_components, n_features)``. Each axis is only determined
            up to sign.
          * ``mean`` -- per-feature mean subtracted from ``X`` before the
            SVD; new data must be centered with it before projecting.
    """
    X = tf.convert_to_tensor(X)
    # reduce_mean on an integer tensor truncates the mean, and
    # tf.linalg.svd does not accept integer dtypes, so work in float.
    if X.dtype.is_integer:
        X = tf.cast(X, tf.float32)

    mean = tf.reduce_mean(X, axis=0)
    centered_X = X - mean

    # tf.linalg.svd returns (s, u, v) with X == u @ diag(s) @ adjoint(v).
    # Unlike numpy it returns v itself (not v^H), so the principal axes are
    # the COLUMNS of v; slicing rows of v (as `V[:n_components]` did) picks
    # the wrong vectors. `tf.svd` is a TF1-only alias; use tf.linalg.svd.
    _, _, V = tf.linalg.svd(centered_X, full_matrices=False)

    # adjoint(V) == V^H, whose rows are the principal axes (the analogue of
    # numpy's Vh); keep the leading ones.
    components = tf.linalg.adjoint(V)[:n_components]
    return components, mean