CCA & Permutation Testing
Overview
Canonical Correlation Analysis (CCA) with permutation testing is a standard protocol for evaluating the relationship between multimodal brain data representations. This recipe provides a reproducible approach for testing whether learned representations capture meaningful biological signal beyond chance.
Background
Canonical Correlation Analysis (CCA) identifies linear combinations of two sets of variables that maximize their correlation. In the context of foundation models:
- Use Case 1: Compare model embeddings from different modalities (e.g., fMRI vs genetics)
- Use Case 2: Test if model representations align with behavioral or clinical outcomes
- Use Case 3: Validate cross-modal alignment in multi-modal foundation models
Permutation Testing provides statistical significance by: 1. Computing CCA on real data 2. Shuffling one modality randomly N times 3. Computing CCA on each permutation 4. Comparing real correlation against null distribution
Protocol
1. Data Preparation
import numpy as np
from sklearn.cross_decomposition import CCA
# Load your model embeddings
# X: (n_samples, n_features_1) - e.g., fMRI embeddings
# Y: (n_samples, n_features_2) - e.g., genomic embeddings
X = model_fmri.encode(fmri_data) # Shape: (N, D1)
Y = model_genomics.encode(genomic_data) # Shape: (N, D2)
# Standardize
X = (X - X.mean(axis=0)) / X.std(axis=0)
Y = (Y - Y.mean(axis=0)) / Y.std(axis=0)
2. Run CCA
# Number of canonical components to compute
n_components = min(X.shape[1], Y.shape[1], 20)
cca = CCA(n_components=n_components)
X_c, Y_c = cca.fit_transform(X, Y)
# Compute canonical correlations
canonical_corrs = [np.corrcoef(X_c[:, i], Y_c[:, i])[0, 1]
for i in range(n_components)]
print(f"Top 5 canonical correlations: {canonical_corrs[:5]}")
3. Permutation Test
from tqdm import tqdm
n_permutations = 1000
perm_corrs = []
for _ in tqdm(range(n_permutations)):
# Shuffle one modality
idx = np.random.permutation(len(Y))
Y_perm = Y[idx]
# Run CCA on permuted data
cca_perm = CCA(n_components=n_components)
X_c_perm, Y_c_perm = cca_perm.fit_transform(X, Y_perm)
# Store top canonical correlation
perm_corr = np.corrcoef(X_c_perm[:, 0], Y_c_perm[:, 0])[0, 1]
perm_corrs.append(perm_corr)
perm_corrs = np.array(perm_corrs)
4. Compute P-Value
# P-value: fraction of permutations with correlation >= observed
real_corr = canonical_corrs[0]
p_value = (perm_corrs >= real_corr).mean()
print(f"Real CCA correlation: {real_corr:.4f}")
print(f"Mean permuted correlation: {perm_corrs.mean():.4f}")
print(f"P-value: {p_value:.4f}")
# Effect size
effect_size = (real_corr - perm_corrs.mean()) / perm_corrs.std()
print(f"Effect size (Z-score): {effect_size:.2f}")
5. Visualization
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
# Histogram of permutation null distribution
axes[0].hist(perm_corrs, bins=50, alpha=0.7, label='Permuted')
axes[0].axvline(real_corr, color='red', linestyle='--',
label=f'Observed (p={p_value:.3f})')
axes[0].set_xlabel('Canonical Correlation')
axes[0].set_ylabel('Frequency')
axes[0].set_title('CCA Permutation Test')
axes[0].legend()
# Scree plot of canonical correlations
axes[1].plot(range(1, len(canonical_corrs)+1), canonical_corrs,
marker='o', label='Real Data')
axes[1].axhline(perm_corrs.mean(), color='gray', linestyle='--',
label='Permuted Mean')
axes[1].set_xlabel('Canonical Component')
axes[1].set_ylabel('Correlation')
axes[1].set_title('Canonical Correlations (Scree Plot)')
axes[1].legend()
plt.tight_layout()
plt.savefig('cca_permutation_results.png', dpi=150)
Interpretation Guidelines
Statistical Significance
- p < 0.05: Significant alignment between modalities
- p < 0.01: Strong evidence of cross-modal relationship
- p < 0.001: Very strong evidence
Effect Size
- Z > 2: Small to moderate effect
- Z > 3: Moderate to large effect
- Z > 5: Large effect
Multiple Comparisons
When testing multiple canonical components, apply correction:
from statsmodels.stats.multitest import multipletests
# Compute p-value for each component
p_values = [(perm_corrs >= corr).mean() for corr in canonical_corrs]
# Bonferroni correction
_, p_corrected, _, _ = multipletests(p_values, method='bonferroni')
Example Use Cases
Use Case 1: Multi-Modal Foundation Model Validation
Test if a multi-modal brain model's fMRI and genetics branches learn aligned representations:
# Extract embeddings from each branch
fmri_emb = model.encode_fmri(fmri_data)
gene_emb = model.encode_genetics(genetics_data)
# Run CCA + permutation test
# → If p < 0.05, model successfully learns cross-modal alignment
Use Case 2: Behavioral Prediction Alignment
Test if brain embeddings correlate with behavioral scores:
# Brain embeddings (N x D)
brain_emb = model.encode(fmri_data)
# Behavioral scores (N x K)
behavior_scores = load_behavioral_data()
# Run CCA
# → High correlation suggests embeddings capture behaviorally relevant features
ITU AI4H Alignment
This protocol aligns with:
- DEL3 Section 5.2: Validation and verification requirements
- DEL10.8: Neurology-specific evaluation protocols
- DEL0.1: Statistical significance standards for AI4H benchmarks
References
- Hotelling, H. (1936). Relations between two sets of variates. Biometrika, 28(3/4), 321-377.
- Witten, D. M., Tibshirani, R., & Hastie, T. (2009). A penalized matrix decomposition. Biostatistics, 10(3), 515-534.
- Wang, W., Yan, X., Lee, H., & Livescu, K. (2016). Deep Variational Canonical Correlation Analysis. arXiv preprint arXiv:1610.03454.