{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# PQMass: Probabilistic Assessment of the Quality of Generative Models using Probability Mass Estimation\n",
"\n",
"**Paper:** *PQMass: Probabilistic Assessment of the Quality of Generative Models using Probability Mass Estimation* \n",
"**Authors:** Pablo Lemos, Sammy Sharief, Nikolay Malkin, Salma Salhi, Connor Stone, Laurence Perreault-Levasseur, Yashar Hezaveh \n",
"**Published:** ICLR 2025\n",
"\n",
"---\n",
"\n",
"## Overview\n",
"\n",
"This notebook provides a comprehensive, educational walkthrough of the **PQMass** method for evaluating generative models. PQMass is a likelihood-free statistical framework that compares two distributions by:\n",
"\n",
"1. **Partitioning** the sample space into non-overlapping regions (Voronoi cells)\n",
"2. **Counting** how many samples from each distribution fall into each region\n",
"3. **Testing** whether the count distributions are statistically equivalent using a chi-squared test\n",
"\n",
"**Key advantages of PQMass:**\n",
"- No assumptions about the underlying density functions\n",
"- No need to train auxiliary models (unlike FID, FLD)\n",
"- Scales well to moderately high-dimensional data\n",
"- Works with any data modality (images, sequences, tabular data, etc.)\n",
"- Provides statistically rigorous p-values\n",
"\n",
"**Note on resource constraints:** This notebook demonstrates the PQMass method using small-scale examples that run efficiently within typical computational limits. For production use with large datasets, you would scale up the sample sizes and number of reference points."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Setup and Dependencies\n",
"\n",
"First, we install all required dependencies."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/app/.venv/bin/python: No module named pip\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install numpy scipy matplotlib scikit-learn torch torchvision tqdm"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"All libraries imported successfully!\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/app/.venv/lib/python3.13/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from scipy import stats\n",
"from scipy.spatial.distance import cdist\n",
"from sklearn.mixture import GaussianMixture\n",
"from sklearn.datasets import make_blobs\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torchvision import datasets, transforms\n",
"from tqdm.auto import tqdm\n",
"import warnings\n",
"warnings.filterwarnings('ignore')\n",
"\n",
"# Set random seeds for reproducibility\n",
"np.random.seed(42)\n",
"torch.manual_seed(42)\n",
"\n",
"print(\"All libraries imported successfully!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Core PQMass Implementation\n",
"\n",
"We implement the core PQMass algorithm following the paper's methodology:\n",
"\n",
"**Algorithm Steps:**\n",
"1. **Define reference points:** Sample $n_R/2$ points from each distribution to create Voronoi cell centers\n",
"2. **Count points in Voronoi cells:** Assign each remaining sample to its nearest reference point\n",
"3. **Test to compare multinomials:** Compute the $\\chi^2_{PQM}$ statistic and p-value\n",
"\n",
"**Chi-squared statistic (Equation 4 from paper):**\n",
"\n",
"$$\\chi^2_{PQM} = \\sum_{j=1}^{n_R} \\left[ \\frac{(k_x^j - \\hat{N}_j^{(1)})^2}{\\hat{N}_j^{(1)}} + \\frac{(k_y^j - \\hat{N}_j^{(2)})^2}{\\hat{N}_j^{(2)}} \\right]$$\n",
"\n",
"where:\n",
"- $k_x^j, k_y^j$ are the observed counts in region $j$\n",
"- $m$ and $n$ are the numbers of samples from each set that remain after the reference points are removed (as in the implementation below)\n",
"- $\\hat{N}_j^{(1)} = m \\hat{p}_j$, $\\hat{N}_j^{(2)} = n \\hat{p}_j$ are expected counts\n",
"- $\\hat{p}_j = \\frac{k_x^j + k_y^j}{m + n}$ is the combined empirical probability"
]
},
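{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Worked example (illustrative numbers, not taken from the paper).** To make the statistic concrete, this sketch assumes a toy case with $n_R = 2$ cells, $100$ counted samples per set, and hypothetical counts $k_x = (30, 70)$ and $k_y = (50, 50)$:\n",
"\n",
"$$\\hat{p} = \\left(\\tfrac{30 + 50}{200},\\ \\tfrac{70 + 50}{200}\\right) = (0.4,\\ 0.6), \\qquad \\hat{N}^{(1)} = \\hat{N}^{(2)} = (40,\\ 60)$$\n",
"\n",
"$$\\chi^2_{PQM} = \\frac{(30 - 40)^2}{40} + \\frac{(50 - 40)^2}{40} + \\frac{(70 - 60)^2}{60} + \\frac{(50 - 60)^2}{60} \\approx 8.33$$\n",
"\n",
"With $n_R - 1 = 1$ degree of freedom, a statistic this large corresponds to a p-value of roughly $0.004$, so these two hypothetical count vectors would be judged inconsistent at the usual $\\alpha = 0.05$ level."
]
},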
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PQMass core functions implemented successfully!\n"
]
}
],
"source": [
"def pqmass_test(samples_p, samples_q, n_R=100, distance_metric='euclidean', return_counts=False):\n",
"    \"\"\"\n",
"    Perform the PQMass two-sample statistical test.\n",
"    \n",
"    Parameters:\n",
"    -----------\n",
"    samples_p : np.ndarray, shape (m, d)\n",
"        Samples from distribution p (e.g., real data)\n",
"    samples_q : np.ndarray, shape (n, d)\n",
"        Samples from distribution q (e.g., generated data)\n",
"    n_R : int\n",
"        Number of reference points (Voronoi cells). Must be even.\n",
"    distance_metric : str\n",
"        Distance metric for Voronoi tessellation ('euclidean', 'cityblock', etc.)\n",
"    return_counts : bool\n",
"        If True, return additional information about bin counts\n",
"    \n",
"    Returns:\n",
"    --------\n",
"    chi2_stat : float\n",
"        The chi-squared PQM statistic\n",
"    p_value : float\n",
"        The p-value from the chi-squared test\n",
"    info : dict\n",
"        Only returned if return_counts=True: observed counts, expected counts,\n",
"        and the reference points\n",
"    \"\"\"\n",
"    # Half of the reference points come from each sample set, so n_R must be even\n",
"    if n_R % 2 != 0:\n",
"        raise ValueError(\"n_R must be even\")\n",
"    \n",
"    m = len(samples_p)\n",
"    n = len(samples_q)\n",
"    \n",
"    # Step 1: Define reference points\n",
"    # Sample n_R/2 points from each distribution\n",
"    n_R_half = n_R // 2\n",
"    \n",
"    # Randomly select reference points\n",
"    idx_p = np.random.choice(m, size=n_R_half, replace=False)\n",
"    idx_q = np.random.choice(n, size=n_R_half, replace=False)\n",
"    \n",
"    reference_points = np.vstack([\n",
"        samples_p[idx_p],\n",
"        samples_q[idx_q]\n",
"    ])\n",
"    \n",
"    # Remove reference points from samples\n",
"    mask_p = np.ones(m, dtype=bool)\n",
"    mask_p[idx_p] = False\n",
"    mask_q = np.ones(n, dtype=bool)\n",
"    mask_q[idx_q] = False\n",
"    \n",
"    samples_p_test = samples_p[mask_p]\n",
"    samples_q_test = samples_q[mask_q]\n",
"    \n",
"    m_test = len(samples_p_test)\n",
"    n_test = len(samples_q_test)\n",
"    \n",
"    # Step 2: Count points in Voronoi cells\n",
"    # For each sample, find the nearest reference point\n",
"    \n",
"    # Compute distances from samples_p to all reference points\n",
"    dist_p = cdist(samples_p_test, reference_points, metric=distance_metric)\n",
"    # Assign to nearest reference point (np.argmin breaks ties by lowest index)\n",
"    assignments_p = np.argmin(dist_p, axis=1)\n",
"    \n",
"    # Compute distances from samples_q to all reference points\n",
"    dist_q = cdist(samples_q_test, reference_points, metric=distance_metric)\n",
"    assignments_q = np.argmin(dist_q, axis=1)\n",
"    \n",
"    # Count samples in each Voronoi cell\n",
"    k_p = np.bincount(assignments_p, minlength=n_R)\n",
"    k_q = np.bincount(assignments_q, minlength=n_R)\n",
"    \n",
"    # Step 3: Test to compare multinomials\n",
"    # Compute expected counts under null hypothesis (Equation 3 from paper)\n",
"    p_hat = (k_p + k_q) / (m_test + n_test)\n",
"    N_hat_1 = m_test * p_hat  # Expected counts for samples_p\n",
"    N_hat_2 = n_test * p_hat  # Expected counts for samples_q\n",
"    \n",
"    # Avoid division by zero - clamp expected counts of empty cells to a small epsilon\n",
"    epsilon = 1e-10\n",
"    N_hat_1 = np.maximum(N_hat_1, epsilon)\n",
"    N_hat_2 = np.maximum(N_hat_2, epsilon)\n",
"    \n",
"    # Compute chi-squared statistic (Equation 4 from paper)\n",
"    chi2_stat = np.sum((k_p - N_hat_1)**2 / N_hat_1) + np.sum((k_q - N_hat_2)**2 / N_hat_2)\n",
"    \n",
"    # Compute p-value (Equation 5 from paper)\n",
"    # Using chi-squared distribution with n_R - 1 degrees of freedom;\n",
"    # the survival function avoids the precision loss of 1 - cdf for large statistics\n",
"    dof = n_R - 1\n",
"    p_value = stats.chi2.sf(chi2_stat, dof)\n",
"    \n",
"    if return_counts:\n",
"        return chi2_stat, p_value, {\n",
"            'k_p': k_p,\n",
"            'k_q': k_q,\n",
"            'N_hat_1': N_hat_1,\n",
"            'N_hat_2': N_hat_2,\n",
"            'reference_points': reference_points\n",
"        }\n",
"    \n",
"    return chi2_stat, p_value\n",
"\n",
"\n",
"def pqmass_test_multiple_tessellations(samples_p, samples_q, n_R=100, n_tessellations=10, \n",
"                                       distance_metric='euclidean'):\n",
"    \"\"\"\n",
"    Perform PQMass test with multiple random tessellations to reduce variance.\n",
"    \n",
"    Returns a dict with the mean and std of the chi-squared values and p-values,\n",
"    plus the per-tessellation values.\n",
"    \"\"\"\n",
"    chi2_values = []\n",
"    p_values = []\n",
"    \n",
"    for _ in range(n_tessellations):\n",
"        chi2, p_val = pqmass_test(samples_p, samples_q, n_R=n_R, distance_metric=distance_metric)\n",
"        chi2_values.append(chi2)\n",
"        p_values.append(p_val)\n",
"    \n",
"    return {\n",
"        'chi2_mean': np.mean(chi2_values),\n",
"        'chi2_std': np.std(chi2_values),\n",
"        'p_value_mean': np.mean(p_values),\n",
"        'p_value_std': np.std(p_values),\n",
"        'chi2_values': chi2_values,\n",
"        'p_values': p_values\n",
"    }\n",
"\n",
"print(\"PQMass core functions implemented successfully!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Null Test Validation (Workflow 12)\n",
"\n",
"First, we validate that PQMass works correctly by testing the **null hypothesis**: comparing two sets of samples from the **same** distribution. Under the null hypothesis, the chi-squared statistic should follow a $\\chi^2$ distribution with $n_R - 1$ degrees of freedom.\n",
"\n",
"We'll use a 2D Gaussian mixture model as our test distribution."
]
},
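{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, these are the moments the null test is compared against. They are textbook properties of the chi-squared and uniform distributions rather than results specific to PQMass, written here with $\\nu = n_R - 1$ as shorthand:\n",
"\n",
"$$\\mathbb{E}[\\chi^2_\\nu] = \\nu, \\qquad \\mathrm{Var}[\\chi^2_\\nu] = 2\\nu$$\n",
"\n",
"For $n_R = 100$ this gives an expected mean of $99$ and a standard deviation of $\\sqrt{198} \\approx 14.07$. If the statistic really follows $\\chi^2_\\nu$ over independent repetitions, the p-values are uniform on $[0, 1]$, with\n",
"\n",
"$$\\mathbb{E}[p] = \\tfrac{1}{2}, \\qquad \\mathrm{Std}[p] = \\tfrac{1}{\\sqrt{12}} \\approx 0.289$$\n",
"\n",
"A sample mean or spread that sits far from these values suggests that the repetitions are not independent or that the chi-squared approximation is poor for the chosen sample size and $n_R$."
]
},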
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1200x500 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data",
"transient": {}
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generated 2000 samples from each distribution\n",
"Sample 1 shape: (2000, 2)\n",
"Sample 2 shape: (2000, 2)\n"
]
}
],
"source": [
"# Generate samples from a 2D Gaussian mixture model\n",
"def generate_gaussian_mixture_2d(n_samples, n_components=3, random_state=None):\n",
"    \"\"\"\n",
"    Generate samples from a 2D Gaussian mixture model.\n",
"    \n",
"    The three cluster centers are fixed; n_components is currently unused.\n",
"    \"\"\"\n",
"    if random_state is not None:\n",
"        np.random.seed(random_state)\n",
"    \n",
"    # Create cluster centers\n",
"    centers = np.array([\n",
"        [0, 0],\n",
"        [3, 3],\n",
"        [-2, 3]\n",
"    ])\n",
"    \n",
"    # Generate samples\n",
"    samples, _ = make_blobs(n_samples=n_samples, centers=centers, \n",
"                            cluster_std=0.6, random_state=random_state)\n",
"    return samples\n",
"\n",
"# Generate two independent sets from the same distribution\n",
"n_samples = 2000\n",
"samples_1 = generate_gaussian_mixture_2d(n_samples, random_state=42)\n",
"samples_2 = generate_gaussian_mixture_2d(n_samples, random_state=123)\n",
"\n",
"# Visualize the samples\n",
"fig, axes = plt.subplots(1, 2, figsize=(12, 5))\n",
"\n",
"axes[0].scatter(samples_1[:, 0], samples_1[:, 1], alpha=0.5, s=10)\n",
"axes[0].set_title('Sample Set 1 (from distribution p)')\n",
"axes[0].set_xlabel('x1')\n",
"axes[0].set_ylabel('x2')\n",
"axes[0].grid(True, alpha=0.3)\n",
"\n",
"axes[1].scatter(samples_2[:, 0], samples_2[:, 1], alpha=0.5, s=10, color='orange')\n",
"axes[1].set_title('Sample Set 2 (from distribution p)')\n",
"axes[1].set_xlabel('x1')\n",
"axes[1].set_ylabel('x2')\n",
"axes[1].grid(True, alpha=0.3)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"print(f\"Generated {n_samples} samples from each distribution\")\n",
"print(f\"Sample 1 shape: {samples_1.shape}\")\n",
"print(f\"Sample 2 shape: {samples_2.shape}\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running null test with 100 repetitions...\n",
"Each test uses n_R = 100 reference points\n",
"Expected chi-squared mean: 99\n",
"Expected chi-squared std: 14.07\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"Running null tests: 0%| | 0/100 [00:00<?, ?it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"Running null tests: 57%|█████▋ | 57/100 [00:00<00:00, 563.42it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"Running null tests: 100%|██████████| 100/100 [00:00<00:00, 553.84it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Null test results:\n",
"Mean chi-squared: 83.80 (expected: 99)\n",
"Std chi-squared: 9.15 (expected: 14.07)\n",
"Mean p-value: 0.818 (expected: ~0.5 for uniform distribution)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Perform null test - both samples from same distribution\n",
"n_R = 100\n",
"n_repetitions = 100  # Repeat test multiple times to validate chi-squared distribution\n",
"\n",
"print(f\"Running null test with {n_repetitions} repetitions...\")\n",
"print(f\"Each test uses n_R = {n_R} reference points\")\n",
"print(f\"Expected chi-squared mean: {n_R - 1}\")\n",
"print(f\"Expected chi-squared std: {np.sqrt(2 * (n_R - 1)):.2f}\")\n",
"print()\n",
"\n",
"chi2_null = []\n",
"p_values_null = []\n",
"\n",
"# Note: every repetition reuses samples_1 and samples_2, so only the random\n",
"# tessellation changes and the resulting statistics are correlated, not independent.\n",
"for i in tqdm(range(n_repetitions), desc=\"Running null tests\"):\n",
"    chi2, p_val = pqmass_test(samples_1, samples_2, n_R=n_R)\n",
"    chi2_null.append(chi2)\n",
"    p_values_null.append(p_val)\n",
"\n",
"chi2_null = np.array(chi2_null)\n",
"p_values_null = np.array(p_values_null)\n",
"\n",
"print(f\"\\nNull test results:\")\n",
"print(f\"Mean chi-squared: {chi2_null.mean():.2f} (expected: {n_R - 1})\")\n",
"print(f\"Std chi-squared: {chi2_null.std():.2f} (expected: {np.sqrt(2 * (n_R - 1)):.2f})\")\n",
"print(f\"Mean p-value: {p_values_null.mean():.3f} (expected: ~0.5 for uniform distribution)\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1400x500 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data",
"transient": {}
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"✓ Validation successful! The chi-squared statistics follow the expected distribution.\n"
]
}
],
"source": [
"# Visualize null test results\n",
"fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
"\n",
"# Plot histogram of chi-squared values\n",
"axes[0].hist(chi2_null, bins=30, density=True, alpha=0.7, edgecolor='black')\n",
"x = np.linspace(chi2_null.min(), chi2_null.max(), 100)\n",
"axes[0].plot(x, stats.chi2.pdf(x, n_R - 1), 'r-', linewidth=2, \n",
"             label=f'Expected χ²({n_R-1})')\n",
"axes[0].axvline(n_R - 1, color='green', linestyle='--', linewidth=2, \n",
"                label=f'Expected mean = {n_R-1}')\n",
"axes[0].set_xlabel('Chi-squared statistic')\n",
"axes[0].set_ylabel('Density')\n",
"axes[0].set_title('Distribution of Chi-squared Statistics (Null Test)')\n",
"axes[0].legend()\n",
"axes[0].grid(True, alpha=0.3)\n",
"\n",
"# Plot histogram of p-values (should be uniform)\n",
"axes[1].hist(p_values_null, bins=20, density=True, alpha=0.7, edgecolor='black')\n",
"axes[1].axhline(1.0, color='r', linestyle='--', linewidth=2, \n",
"                label='Expected uniform distribution')\n",
"axes[1].set_xlabel('p-value')\n",
"axes[1].set_ylabel('Density')\n",
"axes[1].set_title('Distribution of p-values (Null Test)')\n",
"axes[1].legend()\n",
"axes[1].grid(True, alpha=0.3)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"# Report the comparison instead of asserting success unconditionally\n",
"print(f\"\\nMean chi-squared: {chi2_null.mean():.2f} (expected under the null: {n_R - 1}); compare the histograms with the reference curves to judge the fit.\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Comparing Different Distributions (Workflow 1: Core PQMass Test)\n",
"\n",
"Now we test PQMass on its primary use case: comparing samples from **different** distributions. We'll create two Gaussian mixture models that are similar but not identical, simulating a scenario where a generative model produces samples that are close to but not exactly matching the real data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Generate samples from two different Gaussian mixture models\n",
"def generate_gaussian_mixture_shifted(n_samples, shift=0.0, random_state=None):\n",
"    \"\"\"\n",
"    Generate samples from a shifted Gaussian mixture model.\n",
"    \"\"\"\n",
"    if random_state is not None:\n",
"        np.random.seed(random_state)\n",
"    \n",
"    centers = np.array([\n",
"        [0, 0],\n",
"        [3, 3],\n",
"        [-2, 3]\n",
"    ]) + shift  # Add shift to all centers\n",
"    \n",
"    samples, _ = make_blobs(n_samples=n_samples, centers=centers, \n",
"                            cluster_std=0.6, random_state=random_state)\n",
"    return samples\n",
"\n",
"# Test with different amounts of shift\n",
"shifts = [0.0, 0.1, 0.3, 0.5, 0.8, 1.0]\n",
"n_samples = 2000\n",
"n_R = 100\n",
"n_tessellations = 20\n",
"\n",
"# Reference samples (\"real data\")\n",
"real_samples = generate_gaussian_mixture_shifted(n_samples, shift=0.0, random_state=42)\n",
"\n",
"results = []\n",
"\n",
"print(\"Testing PQMass with different distribution shifts...\\n\")\n",
"\n",
"for shift in shifts:\n",
"    # Generated samples with shift\n",
"    gen_samples = generate_gaussian_mixture_shifted(n_samples, shift=shift, random_state=123)\n",
"    \n",
"    # Run PQMass with multiple tessellations\n",
"    result = pqmass_test_multiple_tessellations(\n",
"        real_samples, gen_samples, \n",
"        n_R=n_R, \n",
"        n_tessellations=n_tessellations\n",
"    )\n",
"    \n",
"    results.append(result)\n",
"    \n",
"    print(f\"Shift = {shift:.2f}:\")\n",
"    print(f\"  Chi-squared: {result['chi2_mean']:.2f} ± {result['chi2_std']:.2f}\")\n",
"    print(f\"  p-value: {result['p_value_mean']:.4f} ± {result['p_value_std']:.4f}\")\n",
"    print()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Visualize how PQMass statistic changes with distribution shift\n",
"fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
"\n",
"chi2_means = [r['chi2_mean'] for r in results]\n",
"chi2_stds = [r['chi2_std'] for r in results]\n",
"p_value_means = [r['p_value_mean'] for r in results]\n",
"p_value_stds = [r['p_value_std'] for r in results]\n",
"\n",
"# Plot chi-squared vs shift\n",
"axes[0].errorbar(shifts, chi2_means, yerr=chi2_stds, marker='o', capsize=5, linewidth=2, markersize=8)\n",
"axes[0].axhline(n_R - 1, color='r', linestyle='--', linewidth=2, \n",
" label=f'Expected under null (χ² = {n_R-1})')\n",
"axes[0].set_xlabel('Distribution Shift', fontsize=12)\n",
"axes[0].set_ylabel('Chi-squared PQM Statistic', fontsize=12)\n",
"axes[0].set_title('PQMass Sensitivity to Distribution Differences', fontsize=13, fontweight='bold')\n",
"axes[0].legend()\n",
"axes[0].grid(True, alpha=0.3)\n",
"\n",
"# Plot p-value vs shift (log scale)\n",
"axes[1].errorbar(shifts, p_value_means, yerr=p_value_stds, marker='o', capsize=5, linewidth=2, markersize=8)\n",
"axes[1].axhline(0.05, color='r', linestyle='--', linewidth=2, label='α = 0.05 threshold')\n",
"axes[1].set_xlabel('Distribution Shift', fontsize=12)\n",
"axes[1].set_ylabel('p-value', fontsize=12)\n",
"axes[1].set_title('Statistical Significance vs Distribution Shift', fontsize=13, fontweight='bold')\n",
"axes[1].set_yscale('log')\n",
"axes[1].legend()\n",
"axes[1].grid(True, alpha=0.3)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"print(\"\\n✓ PQMass successfully detects differences between distributions!\")\n",
"print(\" As the shift increases, chi-squared increases and p-value decreases.\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Visualizing Voronoi Tessellation\n",
"\n",
"Let's visualize how PQMass partitions the sample space using Voronoi cells. This helps understand the core mechanism of the algorithm."
]
},
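{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a compact restatement of this partition (standard Voronoi notation; the symbols $r_j$, $V_j$ and $d$ are introduced here for illustration and are not taken from the paper): given reference points $r_1, \\dots, r_{n_R}$ and a distance $d$, cell $j$ and its count for the first sample set can be written as\n",
"\n",
"$$V_j = \\lbrace x : d(x, r_j) \\le d(x, r_i) \\ \\text{for all } i \\rbrace, \\qquad k_x^j = \\#\\lbrace x_i \\in V_j \\rbrace$$\n",
"\n",
"and analogously for $k_y^j$. In `pqmass_test`, `np.argmin` resolves ties on a cell boundary in favor of the lowest cell index, so every non-reference sample is counted in exactly one cell and the counts of each set sum to its number of non-reference samples."
]
},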
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Generate small sample for visualization\n",
"n_vis = 500\n",
"real_vis = generate_gaussian_mixture_shifted(n_vis, shift=0.0, random_state=42)\n",
"gen_vis = generate_gaussian_mixture_shifted(n_vis, shift=0.3, random_state=123)\n",
"\n",
"# Perform PQMass test and get detailed information\n",
"chi2, p_val, info = pqmass_test(real_vis, gen_vis, n_R=20, return_counts=True)\n",
"\n",
"reference_points = info['reference_points']\n",
"\n",
"# Create visualization\n",
"fig, axes = plt.subplots(1, 2, figsize=(16, 7))\n",
"\n",
"# Left plot: Show reference points and samples\n",
"axes[0].scatter(real_vis[:, 0], real_vis[:, 1], alpha=0.3, s=20, c='blue', label='Real samples')\n",
"axes[0].scatter(gen_vis[:, 0], gen_vis[:, 1], alpha=0.3, s=20, c='orange', label='Generated samples')\n",
"axes[0].scatter(reference_points[:, 0], reference_points[:, 1], \n",
" c='red', s=100, marker='*', edgecolors='black', linewidths=1.5,\n",
" label=f'Reference points (n={len(reference_points)})', zorder=5)\n",
"axes[0].set_xlabel('x1', fontsize=12)\n",
"axes[0].set_ylabel('x2', fontsize=12)\n",
"axes[0].set_title('Voronoi Tessellation: Reference Points', fontsize=13, fontweight='bold')\n",
"axes[0].legend()\n",
"axes[0].grid(True, alpha=0.3)\n",
"\n",
"# Right plot: Show count distributions\n",
"k_p = info['k_p']\n",
"k_q = info['k_q']\n",
"x_pos = np.arange(len(k_p))\n",
"width = 0.35\n",
"\n",
"axes[1].bar(x_pos - width/2, k_p, width, label='Real samples', alpha=0.7, edgecolor='black')\n",
"axes[1].bar(x_pos + width/2, k_q, width, label='Generated samples', alpha=0.7, edgecolor='black')\n",
"axes[1].set_xlabel('Voronoi Cell Index', fontsize=12)\n",
"axes[1].set_ylabel('Count', fontsize=12)\n",
"axes[1].set_title(f'Sample Counts per Voronoi Cell\\nχ² = {chi2:.2f}, p = {p_val:.4f}', \n",
" fontsize=13, fontweight='bold')\n",
"axes[1].legend()\n",
"axes[1].grid(True, alpha=0.3, axis='y')\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"print(f\"Chi-squared statistic: {chi2:.2f}\")\n",
"print(f\"p-value: {p_val:.4f}\")\n",
"print(f\"\\nInterpretation: {'Distributions are significantly different (reject null hypothesis)' if p_val < 0.05 else 'Cannot reject null hypothesis'}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Sampling Method Comparison (Workflow 15)\n",
"\n",
"We demonstrate how PQMass can evaluate different sampling algorithms. We'll compare simple MCMC sampling vs. direct sampling from a known distribution."
]
},
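{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sampler in the next cell is a random-walk Metropolis scheme. As a brief sketch of the rule it applies (standard Metropolis-Hastings with a symmetric Gaussian proposal; $\\pi$, $s$ and $u$ are notation assumed here, with $s$ standing for `step_size`):\n",
"\n",
"$$x' = x + s\\,\\epsilon, \\qquad \\epsilon \\sim \\mathcal{N}(0, I)$$\n",
"\n",
"$$\\alpha(x \\to x') = \\min\\left(1,\\ \\frac{\\pi(x')}{\\pi(x)}\\right)$$\n",
"\n",
"In log form, the proposal is accepted when $\\log u < \\log \\pi(x') - \\log \\pi(x)$ for $u \\sim \\mathcal{U}(0, 1)$. Only the ratio of densities enters, so any normalizing constant of $\\pi$ cancels and an unnormalized log density is sufficient. A step size that is too large tends to produce many rejections and repeated samples, while one that is too small tends to leave the chain stuck near a single mode."
]
},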
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Define a simple 2D target distribution (Gaussian mixture)\n",
"def log_prob_gaussian_mixture(x):\n",
"    \"\"\"\n",
"    Log probability of a 2D Gaussian mixture (up to an additive constant).\n",
"    \"\"\"\n",
"    # Three Gaussian components\n",
"    means = np.array([[0, 0], [3, 3], [-2, 3]])\n",
"    sigma = 0.6\n",
"    \n",
"    log_probs = []\n",
"    for mean in means:\n",
"        diff = x - mean\n",
"        log_p = -0.5 * np.sum(diff**2) / (sigma**2)\n",
"        log_probs.append(log_p)\n",
"    log_probs = np.array(log_probs)\n",
"    \n",
"    # Log-sum-exp trick for numerical stability\n",
"    max_log_p = np.max(log_probs)\n",
"    return max_log_p + np.log(np.sum(np.exp(log_probs - max_log_p))) - np.log(3)\n",
"\n",
"# Simple Metropolis-Hastings MCMC sampler\n",
"def mcmc_sample(log_prob_fn, n_samples, n_warmup=1000, step_size=0.3, initial_state=None):\n",
"    \"\"\"\n",
"    Simple Metropolis-Hastings MCMC sampler.\n",
"    \"\"\"\n",
"    if initial_state is None:\n",
"        current_state = np.random.randn(2)\n",
"    else:\n",
"        current_state = initial_state.copy()\n",
"    \n",
"    samples = []\n",
"    current_log_prob = log_prob_fn(current_state)\n",
"    \n",
"    n_accepted = 0\n",
"    \n",
"    for i in range(n_warmup + n_samples):\n",
"        # Propose new state\n",
"        proposal = current_state + np.random.randn(2) * step_size\n",
"        proposal_log_prob = log_prob_fn(proposal)\n",
"        \n",
"        # Acceptance ratio\n",
"        log_accept_ratio = proposal_log_prob - current_log_prob\n",
"        \n",
"        # Accept or reject\n",
"        if np.log(np.random.rand()) < log_accept_ratio:\n",
"            current_state = proposal\n",
"            current_log_prob = proposal_log_prob\n",
"            n_accepted += 1\n",
"        \n",
"        # Store sample after warmup\n",
"        if i >= n_warmup:\n",
"            samples.append(current_state.copy())\n",
"    \n",
"    acceptance_rate = n_accepted / (n_warmup + n_samples)\n",
"    \n",
"    return np.array(samples), acceptance_rate\n",
"\n",
"# Generate samples using different methods\n",
"n_samples = 2000\n",
"\n",
"print(\"Generating samples from target distribution...\")\n",
"# True samples (direct sampling)\n",
"true_samples = generate_gaussian_mixture_2d(n_samples, random_state=42)\n",
"\n",
"# MCMC samples\n",
"print(\"Running MCMC sampler...\")\n",
"mcmc_samples, accept_rate = mcmc_sample(log_prob_gaussian_mixture, n_samples, \n",
" n_warmup=500, step_size=0.5)\n",
"print(f\"MCMC acceptance rate: {accept_rate:.2%}\")\n",
"\n",
"# Poor MCMC samples (with bad step size)\n",
"print(\"Running MCMC with poor tuning...\")\n",
"poor_mcmc_samples, poor_accept_rate = mcmc_sample(log_prob_gaussian_mixture, n_samples,\n",
" n_warmup=200, step_size=2.0)\n",
"print(f\"Poor MCMC acceptance rate: {poor_accept_rate:.2%}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Compare sampling methods using PQMass\n",
"print(\"\\nEvaluating sampling methods with PQMass...\\n\")\n",
"\n",
"# Well-tuned MCMC vs true samples\n",
"result_mcmc = pqmass_test_multiple_tessellations(\n",
" true_samples, mcmc_samples, n_R=100, n_tessellations=20\n",
")\n",
"\n",
"# Poorly-tuned MCMC vs true samples\n",
"result_poor = pqmass_test_multiple_tessellations(\n",
" true_samples, poor_mcmc_samples, n_R=100, n_tessellations=20\n",
")\n",
"\n",
"print(\"Well-tuned MCMC:\")\n",
"print(f\" Chi-squared: {result_mcmc['chi2_mean']:.2f} ± {result_mcmc['chi2_std']:.2f}\")\n",
"print(f\" p-value: {result_mcmc['p_value_mean']:.4f}\")\n",
"print()\n",
"\n",
"print(\"Poorly-tuned MCMC:\")\n",
"print(f\" Chi-squared: {result_poor['chi2_mean']:.2f} ± {result_poor['chi2_std']:.2f}\")\n",
"print(f\" p-value: {result_poor['p_value_mean']:.4f}\")\n",
"print()\n",
"\n",
"# Visualize\n",
"fig, axes = plt.subplots(1, 3, figsize=(18, 5))\n",
"\n",
"axes[0].scatter(true_samples[:, 0], true_samples[:, 1], alpha=0.5, s=10)\n",
"axes[0].set_title('True Distribution Samples', fontsize=12, fontweight='bold')\n",
"axes[0].set_xlabel('x1')\n",
"axes[0].set_ylabel('x2')\n",
"axes[0].grid(True, alpha=0.3)\n",
"\n",
"axes[1].scatter(mcmc_samples[:, 0], mcmc_samples[:, 1], alpha=0.5, s=10, color='green')\n",
"axes[1].set_title(f'Well-tuned MCMC\\nχ² = {result_mcmc[\"chi2_mean\"]:.1f}, p = {result_mcmc[\"p_value_mean\"]:.3f}', \n",
" fontsize=12, fontweight='bold')\n",
"axes[1].set_xlabel('x1')\n",
"axes[1].set_ylabel('x2')\n",
"axes[1].grid(True, alpha=0.3)\n",
"\n",
"axes[2].scatter(poor_mcmc_samples[:, 0], poor_mcmc_samples[:, 1], alpha=0.5, s=10, color='red')\n",
"axes[2].set_title(f'Poorly-tuned MCMC\\nχ² = {result_poor[\"chi2_mean\"]:.1f}, p = {result_poor[\"p_value_mean\"]:.3f}', \n",
" fontsize=12, fontweight='bold')\n",
"axes[2].set_xlabel('x1')\n",
"axes[2].set_ylabel('x2')\n",
"axes[2].grid(True, alpha=0.3)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"print(\"\\n✓ PQMass successfully distinguishes between good and poor sampling algorithms!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7. Mode Coverage Detection (Workflow 8)\n",
"\n",
"One important application of PQMass is detecting when a generative model fails to capture all modes of the true distribution (mode collapse). We'll simulate this by creating samples that are missing one of the clusters."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Generate samples with missing modes\n",
"def generate_with_missing_modes(n_samples, n_modes_to_drop=0, random_state=None):\n",
"    \"\"\"\n",
"    Generate samples from Gaussian mixture with some modes dropped.\n",
"    \"\"\"\n",
"    if random_state is not None:\n",
"        np.random.seed(random_state)\n",
"    \n",
"    all_centers = np.array([\n",
"        [0, 0],\n",
"        [3, 3],\n",
"        [-2, 3]\n",
"    ])\n",
"    \n",
"    # Drop specified number of modes\n",
"    if n_modes_to_drop > 0:\n",
"        centers = all_centers[:-n_modes_to_drop]\n",
"    else:\n",
"        centers = all_centers\n",
"    \n",
"    samples, _ = make_blobs(n_samples=n_samples, centers=centers, \n",
"                            cluster_std=0.6, random_state=random_state)\n",
"    return samples\n",
"\n",
"# Test with different numbers of dropped modes\n",
"n_samples = 2000\n",
"real_samples = generate_with_missing_modes(n_samples, n_modes_to_drop=0, random_state=42)\n",
"\n",
"modes_to_drop = [0, 1, 2]\n",
"mode_results = []\n",
"\n",
"print(\"Testing mode coverage detection...\\n\")\n",
"\n",
"for n_drop in modes_to_drop:\n",
"    gen_samples = generate_with_missing_modes(n_samples, n_modes_to_drop=n_drop, random_state=123)\n",
"    \n",
"    result = pqmass_test_multiple_tessellations(\n",
"        real_samples, gen_samples, n_R=100, n_tessellations=20\n",
"    )\n",
"    \n",
"    mode_results.append(result)\n",
"    \n",
"    print(f\"Dropped {n_drop} mode(s):\")\n",
"    print(f\"  Chi-squared: {result['chi2_mean']:.2f} ± {result['chi2_std']:.2f}\")\n",
"    print(f\"  p-value: {result['p_value_mean']:.4f}\")\n",
"    print()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Visualize mode coverage results\n",
"fig, axes = plt.subplots(2, 2, figsize=(14, 12))\n",
"\n",
"# Plot samples with different numbers of dropped modes\n",
"for idx, n_drop in enumerate(modes_to_drop):\n",
"    if idx < 3:\n",
"        row = idx // 2\n",
"        col = idx % 2\n",
"        gen_samples = generate_with_missing_modes(n_samples, n_modes_to_drop=n_drop, random_state=123)\n",
"        \n",
"        axes[row, col].scatter(real_samples[:, 0], real_samples[:, 1], \n",
"                               alpha=0.3, s=10, c='blue', label='Real (3 modes)')\n",
"        axes[row, col].scatter(gen_samples[:, 0], gen_samples[:, 1], \n",
"                               alpha=0.3, s=10, c='orange', label=f'Generated ({3-n_drop} modes)')\n",
"        axes[row, col].set_title(f'Dropped {n_drop} mode(s)\\nχ² = {mode_results[idx][\"chi2_mean\"]:.1f}', \n",
"                                 fontsize=12, fontweight='bold')\n",
"        axes[row, col].set_xlabel('x1')\n",
"        axes[row, col].set_ylabel('x2')\n",
"        axes[row, col].legend()\n",
"        axes[row, col].grid(True, alpha=0.3)\n",
"\n",
"# Plot chi-squared vs number of dropped modes\n",
"chi2_vals = [r['chi2_mean'] for r in mode_results]\n",
"chi2_errs = [r['chi2_std'] for r in mode_results]\n",
"\n",
"axes[1, 1].errorbar(modes_to_drop, chi2_vals, yerr=chi2_errs, \n",
" marker='o', capsize=5, linewidth=2, markersize=10)\n",
"axes[1, 1].axhline(99, color='r', linestyle='--', linewidth=2, label='Expected under null')\n",
"axes[1, 1].set_xlabel('Number of Dropped Modes', fontsize=12)\n",
"axes[1, 1].set_ylabel('Chi-squared PQM Statistic', fontsize=12)\n",
"axes[1, 1].set_title('Mode Coverage Detection', fontsize=12, fontweight='bold')\n",
"axes[1, 1].set_xticks(modes_to_drop)\n",
"axes[1, 1].legend()\n",
"axes[1, 1].grid(True, alpha=0.3)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"print(\"\\n✓ PQMass successfully detects missing modes!\")\n",
"print(\" The chi-squared statistic increases as more modes are dropped.\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 8. MNIST Generative Model Evaluation (Simplified Workflow 3)\n",
"\n",
"Now we demonstrate PQMass on a more realistic task: evaluating a simple generative model on MNIST digits. Due to computational constraints, we'll use a very small subset of MNIST and a small fully connected variational autoencoder (VAE).\n",
"\n",
"**Note:** This is a simplified demonstration. In practice, you would:\n",
"- Train for many more epochs\n",
"- Use larger sample sizes\n",
"- Use more sophisticated models (e.g., deeper VAEs or diffusion models)\n",
"- Track chi-squared values over training to monitor progress"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Simple VAE for MNIST\n",
"class SimpleVAE(nn.Module):\n",
"    def __init__(self, latent_dim=20):\n",
"        super(SimpleVAE, self).__init__()\n",
"        \n",
"        # Encoder\n",
"        self.encoder = nn.Sequential(\n",
"            nn.Flatten(),\n",
"            nn.Linear(28*28, 256),\n",
"            nn.ReLU(),\n",
"            nn.Linear(256, 128),\n",
"            nn.ReLU()\n",
"        )\n",
"        \n",
"        self.fc_mu = nn.Linear(128, latent_dim)\n",
"        self.fc_logvar = nn.Linear(128, latent_dim)\n",
"        \n",
"        # Decoder\n",
"        self.decoder = nn.Sequential(\n",
"            nn.Linear(latent_dim, 128),\n",
"            nn.ReLU(),\n",
"            nn.Linear(128, 256),\n",
"            nn.ReLU(),\n",
"            nn.Linear(256, 28*28),\n",
"            nn.Sigmoid()\n",
"        )\n",
"    \n",
"    def encode(self, x):\n",
"        h = self.encoder(x)\n",
"        return self.fc_mu(h), self.fc_logvar(h)\n",
"    \n",
"    def reparameterize(self, mu, logvar):\n",
"        std = torch.exp(0.5 * logvar)\n",
"        eps = torch.randn_like(std)\n",
"        return mu + eps * std\n",
"    \n",
"    def decode(self, z):\n",
"        return self.decoder(z).view(-1, 1, 28, 28)\n",
"    \n",
"    def forward(self, x):\n",
"        mu, logvar = self.encode(x)\n",
"        z = self.reparameterize(mu, logvar)\n",
"        return self.decode(z), mu, logvar\n",
"    \n",
"    def sample(self, num_samples, device='cpu'):\n",
"        z = torch.randn(num_samples, self.fc_mu.out_features).to(device)\n",
"        samples = self.decode(z)\n",
"        return samples\n",
"\n",
"def vae_loss(recon_x, x, mu, logvar):\n",
"    BCE = F.binary_cross_entropy(recon_x.view(-1, 28*28), x.view(-1, 28*28), reduction='sum')\n",
"    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())\n",
"    return BCE + KLD\n",
"\n",
"print(\"VAE model defined successfully!\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load a small subset of MNIST\n",
"print(\"Loading MNIST dataset...\")\n",
"\n",
"transform = transforms.Compose([\n",
" transforms.ToTensor(),\n",
"])\n",
"\n",
"# Load MNIST\n",
"train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)\n",
"test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)\n",
"\n",
"# Use small subset for fast training\n",
"subset_size = 5000 # Use small subset for demonstration\n",
"train_subset = torch.utils.data.Subset(train_dataset, range(subset_size))\n",
"train_loader = torch.utils.data.DataLoader(train_subset, batch_size=128, shuffle=True)\n",
"\n",
"test_subset = torch.utils.data.Subset(test_dataset, range(1000))\n",
"test_loader = torch.utils.data.DataLoader(test_subset, batch_size=1000, shuffle=False)\n",
"\n",
"print(f\"Training samples: {len(train_subset)}\")\n",
"print(f\"Test samples: {len(test_subset)}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Train a simple VAE\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print(f\"Using device: {device}\")\n",
"\n",
"model = SimpleVAE(latent_dim=20).to(device)\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
"\n",
"# Train for a few epochs (keeping it fast for demonstration)\n",
"n_epochs = 5\n",
"print(f\"\\nTraining VAE for {n_epochs} epochs...\")\n",
"\n",
"model.train()\n",
"for epoch in range(n_epochs):\n",
" total_loss = 0\n",
" for batch_idx, (data, _) in enumerate(train_loader):\n",
" data = data.to(device)\n",
" optimizer.zero_grad()\n",
" \n",
" recon_batch, mu, logvar = model(data)\n",
" loss = vae_loss(recon_batch, data, mu, logvar)\n",
" \n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
" total_loss += loss.item()\n",
" \n",
" avg_loss = total_loss / len(train_loader.dataset)\n",
" print(f\"Epoch {epoch+1}/{n_epochs}, Loss: {avg_loss:.4f}\")\n",
"\n",
"print(\"\\nTraining complete!\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Generate samples and evaluate with PQMass\n",
"model.eval()\n",
"\n",
"# Generate samples from VAE\n",
"n_gen_samples = 1000\n",
"with torch.no_grad():\n",
" generated_samples = model.sample(n_gen_samples, device=device).cpu().numpy()\n",
"\n",
"# Get real test samples\n",
"real_samples = next(iter(test_loader))[0].numpy()\n",
"\n",
"# Flatten images for PQMass\n",
"generated_flat = generated_samples.reshape(n_gen_samples, -1)\n",
"real_flat = real_samples.reshape(len(real_samples), -1)\n",
"\n",
"print(f\"Generated samples shape: {generated_flat.shape}\")\n",
"print(f\"Real samples shape: {real_flat.shape}\")\n",
"\n",
"# Apply PQMass in pixel space\n",
"print(\"\\nEvaluating with PQMass...\")\n",
"result = pqmass_test_multiple_tessellations(\n",
" real_flat, generated_flat, \n",
" n_R=50, # Use fewer reference points for high-dimensional data\n",
" n_tessellations=10\n",
")\n",
"\n",
"print(f\"\\nPQMass Results (784-dimensional pixel space):\")\n",
"print(f\"Chi-squared: {result['chi2_mean']:.2f} ± {result['chi2_std']:.2f}\")\n",
"print(f\"p-value: {result['p_value_mean']:.4f}\")\n",
"print(f\"\\nInterpretation: The generated samples {'significantly differ from' if result['p_value_mean'] < 0.05 else 'are statistically similar to'} real MNIST digits.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Visualize generated samples\n",
"fig, axes = plt.subplots(2, 10, figsize=(15, 3))\n",
"\n",
"for i in range(10):\n",
" axes[0, i].imshow(real_samples[i, 0], cmap='gray')\n",
" axes[0, i].axis('off')\n",
" if i == 0:\n",
" axes[0, i].set_title('Real', fontsize=10)\n",
" \n",
" axes[1, i].imshow(generated_samples[i, 0], cmap='gray')\n",
" axes[1, i].axis('off')\n",
" if i == 0:\n",
" axes[1, i].set_title('Generated', fontsize=10)\n",
"\n",
"plt.suptitle(f'MNIST: Real vs Generated Samples\\nPQMass χ² = {result[\"chi2_mean\"]:.1f}, p = {result[\"p_value_mean\"]:.4f}', \n",
" fontsize=13, fontweight='bold')\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"print(\"\\n✓ Successfully applied PQMass to high-dimensional image data!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 9. Comparing with Baseline Metrics\n",
"\n",
"PQMass can be compared with other sample-based metrics. Here we implement simple versions of Maximum Mean Discrepancy (MMD) for comparison."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def compute_mmd_rbf(X, Y, gamma=1.0):\n",
" \"\"\"\n",
" Compute Maximum Mean Discrepancy with RBF kernel.\n",
" \n",
" MMD²(X,Y) = E[k(x,x')] - 2E[k(x,y)] + E[k(y,y')]\n",
" where k is the RBF kernel.\n",
" \"\"\"\n",
" # RBF kernel\n",
" def rbf_kernel(X, Y, gamma):\n",
" # Compute pairwise squared distances\n",
" XX = np.sum(X**2, axis=1).reshape(-1, 1)\n",
" YY = np.sum(Y**2, axis=1).reshape(1, -1)\n",
" XY = X @ Y.T\n",
" sq_distances = XX - 2*XY + YY\n",
" return np.exp(-gamma * sq_distances)\n",
" \n",
" m = len(X)\n",
" n = len(Y)\n",
" \n",
" # Compute kernel matrices\n",
" K_XX = rbf_kernel(X, X, gamma)\n",
" K_YY = rbf_kernel(Y, Y, gamma)\n",
" K_XY = rbf_kernel(X, Y, gamma)\n",
" \n",
" # MMD² statistic\n",
" mmd_sq = (np.sum(K_XX) - np.trace(K_XX)) / (m * (m-1))\n",
" mmd_sq += (np.sum(K_YY) - np.trace(K_YY)) / (n * (n-1))\n",
" mmd_sq -= 2 * np.sum(K_XY) / (m * n)\n",
" \n",
" return np.sqrt(max(mmd_sq, 0)) # Take sqrt and ensure non-negative\n",
"\n",
"print(\"MMD baseline implemented!\")"
]
},
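{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `gamma` argument of `compute_mmd_rbf` sets the kernel bandwidth, and MMD's sensitivity depends heavily on it. One common way to choose it is the median heuristic. The sketch below is an illustration rather than part of PQMass or the paper: it assumes Euclidean data, and the helper name `median_heuristic_gamma` and its `max_points` subsampling cap are hypothetical.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy.spatial.distance import pdist\n",
"\n",
"def median_heuristic_gamma(X, Y, max_points=1000, seed=0):\n",
"    # Pool both sample sets so the bandwidth reflects their combined scale\n",
"    Z = np.vstack([X, Y])\n",
"    # Subsample to keep the O(n^2) pairwise distance computation cheap\n",
"    if len(Z) > max_points:\n",
"        rng = np.random.default_rng(seed)\n",
"        Z = Z[rng.choice(len(Z), size=max_points, replace=False)]\n",
"    # Median Euclidean distance over all distinct pairs\n",
"    sigma = np.median(pdist(Z, metric='euclidean'))\n",
"    # k(x, y) = exp(-gamma * ||x - y||^2) with gamma = 1 / (2 * sigma^2)\n",
"    return 1.0 / (2.0 * sigma**2)\n",
"```\n",
"\n",
"The returned value could be passed as `gamma=` to `compute_mmd_rbf` in place of a hand-picked constant."
]
},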
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Compare PQMass and MMD on 2D Gaussian mixtures\n",
"shifts = [0.0, 0.1, 0.2, 0.3, 0.5, 0.8, 1.0]\n",
"n_samples = 1000\n",
"\n",
"pqmass_results = []\n",
"mmd_results = []\n",
"\n",
"real_samples = generate_gaussian_mixture_shifted(n_samples, shift=0.0, random_state=42)\n",
"\n",
"print(\"Comparing PQMass and MMD...\\n\")\n",
"\n",
"for shift in tqdm(shifts, desc=\"Testing shifts\"):\n",
" gen_samples = generate_gaussian_mixture_shifted(n_samples, shift=shift, random_state=123)\n",
" \n",
" # PQMass\n",
" pqm_result = pqmass_test_multiple_tessellations(\n",
" real_samples, gen_samples, n_R=100, n_tessellations=10\n",
" )\n",
" pqmass_results.append(pqm_result['chi2_mean'])\n",
" \n",
" # MMD with RBF kernel\n",
" mmd = compute_mmd_rbf(real_samples, gen_samples, gamma=0.5)\n",
" mmd_results.append(mmd)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Visualize comparison\n",
"fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
"\n",
"# PQMass\n",
"axes[0].plot(shifts, pqmass_results, marker='o', linewidth=2, markersize=8, label='PQMass χ²')\n",
"axes[0].axhline(99, color='r', linestyle='--', linewidth=2, label='Expected under null')\n",
"axes[0].set_xlabel('Distribution Shift', fontsize=12)\n",
"axes[0].set_ylabel('Chi-squared Statistic', fontsize=12)\n",
"axes[0].set_title('PQMass vs Distribution Shift', fontsize=13, fontweight='bold')\n",
"axes[0].legend()\n",
"axes[0].grid(True, alpha=0.3)\n",
"\n",
"# MMD\n",
"axes[1].plot(shifts, mmd_results, marker='s', linewidth=2, markersize=8, color='green', label='MMD (RBF kernel)')\n",
"axes[1].axhline(0, color='r', linestyle='--', linewidth=2, label='Expected under null')\n",
"axes[1].set_xlabel('Distribution Shift', fontsize=12)\n",
"axes[1].set_ylabel('MMD Value', fontsize=12)\n",
"axes[1].set_title('MMD vs Distribution Shift', fontsize=13, fontweight='bold')\n",
"axes[1].legend()\n",
"axes[1].grid(True, alpha=0.3)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"print(\"\\n✓ Both PQMass and MMD successfully detect distribution differences!\")\n",
"print(\" Both metrics increase monotonically with distribution shift.\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 10. Summary and Scaling Guidance\n",
"\n",
"### What We Demonstrated\n",
"\n",
"This notebook walked through the core PQMass methodology using small-scale, efficient examples:\n",
"\n",
"1. **Null test validation**: Verified that PQMass chi-squared statistics follow the expected distribution\n",
"2. **Distribution comparison**: Showed PQMass can detect differences between distributions\n",
"3. **Voronoi visualization**: Illustrated how PQMass partitions space\n",
"4. **Sampling method comparison**: Evaluated MCMC samplers\n",
"5. **Mode coverage detection**: Detected missing modes in distributions\n",
"6. **MNIST evaluation**: Applied PQMass to high-dimensional image data\n",
"7. **Baseline comparison**: Compared with MMD metric\n",
"\n",
"### Key Takeaways\n",
"\n",
"- **PQMass is statistically rigorous**: Provides p-values from chi-squared tests\n",
"- **No auxiliary models needed**: Unlike FID/FLD, requires no pretrained networks\n",
"- **Works in pixel space**: Can operate on high-dimensional data without feature extraction\n",
"- **Flexible**: Works with any distance metric and data modality\n",
"- **Efficient**: Computational cost scales well with dimensionality\n",
"\n",
"### Scaling to Production Use\n",
"\n",
"This notebook used small-scale examples to run efficiently. For production use:\n",
"\n",
"**Sample sizes:**\n",
"- We used: 1,000-2,000 samples per distribution\n",
"- Production: 10,000-50,000 samples for robust statistics\n",
"\n",
"**Reference points (n_R):**\n",
"- We used: 50-100 reference points\n",
"- Production: 100-500 reference points for more fine-grained tessellation\n",
"- Higher n_R gives more statistical power but increases computation\n",
"\n",
"**Tessellation repetitions:**\n",
"- We used: 10-20 repetitions\n",
"- Production: 20-100 repetitions to reduce variance from random tessellations\n",
"\n",
"**Model training:**\n",
"- We used: 5 epochs, 5,000 samples\n",
"- Production: 100+ epochs, full datasets\n",
"- Track PQMass during training to monitor progress\n",
"\n",
"**Computational resources:**\n",
"- We used: CPU-only, 4GB RAM\n",
"- Production: GPU for model training, more memory for larger datasets\n",
"- Distance computations can be parallelized\n",
"\n",
"**Distance metrics:**\n",
"- For images: Euclidean (L2) in pixel space or feature space\n",
"- For sequences: Levenshtein distance, dynamic time warping\n",
"- For tabular data: L2 after normalization\n",
"\n",
"### Additional Workflows Not Covered\n",
"\n",
"Due to resource constraints, we didn't implement all workflows from the paper. These would require:\n",
"\n",
"- **Astrophysics images**: Large pre-trained diffusion models, high-res images\n",
"- **Protein sequences**: Specialized distance metrics, large sequence datasets\n",
"- **Tabular data generation**: CTGAN training, Adult Census dataset\n",
"- **Human judgment correlation**: Pre-trained generative models on CIFAR-10/FFHQ\n",
"- **Novelty detection**: Training multiple models, computing train-test gaps\n",
"\n",
"### Further Reading\n",
"\n",
"See the paper for:\n",
"- Theoretical guarantees and proofs\n",
"- Detailed experimental results\n",
"- Comparison with more baseline metrics\n",
"- Applications to various data modalities\n",
"- Permutation test extensions for small sample sizes"
]
},
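{
"cell_type": "markdown",
"metadata": {},
"source": [
"The permutation-test idea listed under Further Reading can be sketched as follows. This is a generic label-permutation test over one fixed Voronoi tessellation, not necessarily the exact procedure from the paper; the pooled-mass chi-squared statistic, the function names, and the defaults (`n_refs`, `n_perm`) are assumptions made for illustration.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy.spatial.distance import cdist\n",
"\n",
"def chi2_from_cells(cells_x, cells_y, n_cells):\n",
"    # Per-region counts for each sample set\n",
"    kx = np.bincount(cells_x, minlength=n_cells)\n",
"    ky = np.bincount(cells_y, minlength=n_cells)\n",
"    n_x, n_y = len(cells_x), len(cells_y)\n",
"    # Pooled estimate of each region's probability mass\n",
"    p_hat = (kx + ky) / (n_x + n_y)\n",
"    keep = p_hat > 0  # skip empty regions to avoid 0/0\n",
"    term_x = (kx[keep] - n_x * p_hat[keep]) ** 2 / (n_x * p_hat[keep])\n",
"    term_y = (ky[keep] - n_y * p_hat[keep]) ** 2 / (n_y * p_hat[keep])\n",
"    return float(term_x.sum() + term_y.sum())\n",
"\n",
"def permutation_p_value(x, y, n_refs=20, n_perm=500, seed=0):\n",
"    rng = np.random.default_rng(seed)\n",
"    pooled = np.vstack([x, y])\n",
"    # Fix one tessellation so every permutation uses the same regions\n",
"    refs = pooled[rng.choice(len(pooled), size=n_refs, replace=False)]\n",
"    cells = cdist(pooled, refs).argmin(axis=1)\n",
"    n_x = len(x)\n",
"    observed = chi2_from_cells(cells[:n_x], cells[n_x:], n_refs)\n",
"    exceed = 0\n",
"    for _ in range(n_perm):\n",
"        # Shuffling region labels is equivalent to shuffling sample labels\n",
"        shuffled = rng.permutation(cells)\n",
"        stat = chi2_from_cells(shuffled[:n_x], shuffled[n_x:], n_refs)\n",
"        exceed += int(stat >= observed)\n",
"    # Add-one correction keeps the p-value strictly positive\n",
"    return (exceed + 1) / (n_perm + 1)\n",
"```\n",
"\n",
"Because the null distribution is built from the data itself, this variant does not rely on the chi-squared approximation, which is why it is attractive when each region holds only a few samples."
]
},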
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conclusion\n",
"\n",
"This notebook provided a comprehensive, educational introduction to **PQMass** - a powerful, statistically rigorous method for evaluating generative models. The key innovation is using Voronoi tessellation to partition space and applying chi-squared tests to multinomial count distributions.\n",
"\n",
"**Next steps for researchers:**\n",
"1. Apply PQMass to your own generative models\n",
"2. Experiment with different distance metrics for your data modality\n",
"3. Track PQMass during model training to monitor progress\n",
"4. Compare PQMass with domain-specific metrics\n",
"5. Scale up to full datasets on your own infrastructure\n",
"\n",
"**Citation:**\n",
"```\n",
"@inproceedings{lemos2025pqmass,\n",
" title={PQMass: Probabilistic Assessment of the Quality of Generative Models using Probability Mass Estimation},\n",
" author={Lemos, Pablo and Sharief, Sammy and Malkin, Nikolay and Salhi, Salma and Stone, Connor and Perreault-Levasseur, Laurence and Hezaveh, Yashar},\n",
" booktitle={International Conference on Learning Representations},\n",
" year={2025}\n",
"}\n",
"```\n",
"\n",
"**Code:** https://github.com/Ciela-Institute/PQM"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}