Created
November 20, 2025 16:23
-
-
Save akshayka/0363f27ff94e52c9d1f3f10b82f28d90 to your computer and use it in GitHub Desktop.
embedding_mnist.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # /// script | |
| # requires-python = ">=3.13" | |
| # dependencies = [ | |
| # "altair", | |
| # "marimo>=0.17.0", | |
| # "matplotlib", | |
| # "numpy", | |
| # "pandas", | |
| # "polars", | |
| # "pyarrow", | |
| # "pymde", | |
| # "pyzmq", | |
| # "torch", | |
| # ] | |
| # /// | |
| import marimo | |
# Version of marimo that generated this notebook file.
__generated_with = "0.17.8"
# The marimo app every cell below registers itself on via @app.cell.
# width="columns" selects the notebook's page layout (presumably marimo's
# multi-column layout — confirm in marimo docs).
app = marimo.App(width="columns")
@app.cell
def _():
    # Shared imports plus the MNIST dataset used by the other cells.
    import marimo as mo
    import altair as alt
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import pymde
    import torch
    # numpy, polars and pyarrow are imported but not listed in the return tuple.
    # NOTE(review): presumably kept so those dependencies are importable by
    # marimo/pandas at runtime — confirm before removing.
    import polars as pl
    import pyarrow
    # Dataset object; later cells read mnist.data and
    # mnist.attributes["digits"] from it.
    mnist = pymde.datasets.MNIST()
    return alt, mnist, mo, pd, plt, pymde, torch
@app.cell
def _(mnist, mo, pymde, torch):
    def compute_embedding(embedding_dim, constraint=None, quadratic=False):
        """Embed ``mnist.data`` with ``pymde.preserve_neighbors``.

        A warning callout is appended to the cell output while the
        embedding is computed and cleared once it finishes.

        Args:
            embedding_dim: target dimension of the embedding.
            constraint: PyMDE constraint; ``pymde.Standardized()`` when None.
            quadratic: if True, use a Quadratic attractive penalty and no
                repulsive penalty (the UI labels this "Spectral embedding");
                otherwise Log1p attraction with Log repulsion.

        Returns:
            Tuple of (result of ``mde.embed()``, ``mde.distortions()``).
        """
        notice = mo.md("Your embedding is being computed ... hang tight!")
        mo.output.append(notice.callout(kind="warn"))

        if constraint is None:
            constraint = pymde.Standardized()

        if quadratic:
            attractive, repulsive = pymde.penalties.Quadratic, None
        else:
            attractive, repulsive = pymde.penalties.Log1p, pymde.penalties.Log

        # Run on the GPU whenever torch can see one.
        device = "cuda" if torch.cuda.is_available() else "cpu"

        problem = pymde.preserve_neighbors(
            mnist.data,
            attractive_penalty=attractive,
            repulsive_penalty=repulsive,
            embedding_dim=embedding_dim,
            constraint=constraint,
            device=device,
            verbose=True,
        )
        result = problem.embed(verbose=True)
        mo.output.clear()
        return result, problem.distortions()

    return (compute_embedding,)
@app.cell
def _(mnist, pd, torch):
    def dataframe_from_embedding(embedding, samples=20000):
        """Build a plotting DataFrame from a random subsample of the embedding.

        Args:
            embedding: torch tensor with one row per MNIST image; columns 0
                and 1 become ``x`` and ``y``. May live on any device.
            samples: maximum number of rows, drawn without replacement.

        Returns:
            pandas DataFrame with columns ``index`` (row into mnist.data),
            ``x``, ``y`` and ``digit``.
        """
        # Draw the sample once and convert to numpy, so the same plain
        # integer array indexes every column.
        indices = torch.randperm(mnist.data.shape[0])[:samples].numpy()
        # compute_embedding may run on "cuda"; Tensor.numpy() only works on
        # CPU tensors that do not require grad, so move/detach first.
        embedding_np = embedding.detach().cpu().numpy()[indices]
        digits_np = mnist.attributes["digits"].detach().cpu().numpy()[indices]
        return pd.DataFrame(
            {
                "index": indices,
                "x": embedding_np[:, 0],
                "y": embedding_np[:, 1],
                "digit": digits_np,
            }
        )

    return (dataframe_from_embedding,)
@app.cell
def _(alt):
    def scatter(df, size=4, domain=(-2.5, 2.5)):
        """Return an Altair scatter chart of an embedding DataFrame.

        Args:
            df: DataFrame with columns ``x``, ``y`` (quantitative) and
                ``digit`` (nominal, used for color).
            size: circle mark size.
            domain: (min, max) axis range shared by x and y.

        Returns:
            A 500x500 ``alt.Chart``.
        """
        return (
            alt.Chart(df)
            # clip=True keeps points outside ``domain`` from being drawn
            # over the axes and labels.
            .mark_circle(size=size, clip=True)
            .encode(
                x=alt.X("x:Q").scale(domain=domain),
                y=alt.Y("y:Q").scale(domain=domain),
                color=alt.Color("digit:N"),
            )
            .properties(width=500, height=500)
        )

    return (scatter,)
@app.cell
def _(mnist, plt):
    def show_images(indices, max_images=10):
        """Plot up to ``max_images`` MNIST digits side by side in one row.

        Args:
            indices: sequence of row indices into ``mnist.data``.
            max_images: maximum number of images to draw.

        Returns:
            The matplotlib Figure.

        Raises:
            ValueError: if no indices remain after truncation.
        """
        indices = indices[:max_images]
        if len(indices) == 0:
            raise ValueError("show_images needs at least one index")
        images = mnist.data.reshape((-1, 28, 28))[indices]
        # squeeze=False always returns a 2-D array of Axes, so a single loop
        # covers both the one-image and many-image cases.
        fig, axes = plt.subplots(1, len(indices), squeeze=False)
        fig.set_size_inches(12.5, 1.5)
        for image, ax in zip(images, axes.flat):
            ax.imshow(image, cmap="gray")
            ax.set_yticks([])
            ax.set_xticks([])
        plt.tight_layout()
        return fig

    return (show_images,)
@app.cell
def _(mo):
    # Notebook title.
    mo.md(r"""
# Embedding MNIST
""")
    return
@app.cell
def _(mo):
    # Short description of what the scatter plot shows.
    mo.md(r"""
Here's an **embedding of MNIST**: each point represents a digit,
with similar digits close to each other.
""")
    return
@app.cell
def _(mo):
    # UI toggle for the embedding's penalty choice. Its value is the default
    # for the --quadratic CLI flag, which is passed on to
    # compute_embedding(quadratic=...).
    quadratic = mo.ui.switch(value=False, label="Spectral embedding?")
    # Bare last expression: displayed as this cell's output.
    quadratic
    return (quadratic,)
@app.cell
def _(df, mo, scatter):
    # Wrap the Altair scatter in a marimo UI element. Downstream cells read
    # the current chart selection from chart.value.
    chart = mo.ui.altair_chart(scatter(df))
    # Bare last expression: displayed as this cell's output.
    chart
    return (chart,)
@app.cell
def _(chart, mo):
    # Table of the rows currently selected on the chart. It is rendered
    # inside the preview cell, and rows picked here (table.value) take
    # priority there over the chart selection.
    table = mo.ui.table(chart.value)
    return (table,)
@app.cell
def _(chart, mo, show_images, table):
    # Halt here (mo.stop) until the chart has a non-empty selection.
    mo.stop(not len(chart.value))

    # Preview up to 10 images. Rows picked in the table win; with no table
    # rows picked, use the chart selection instead.
    if len(table.value):
        _source = table.value
    else:
        _source = chart.value
    selected_images = show_images(list(_source["index"]))

    mo.md(
        f"""
**Here's a preview of the images you've selected**:
{mo.as_html(selected_images)}
Here's all the data you've selected.
{table}
"""
    )
    return
@app.cell
def _(chart):
    # Binds the current chart selection to `selection`.
    # NOTE(review): the name is not returned and no other cell here uses it;
    # looks like a leftover for ad-hoc inspection — confirm before removing.
    selection = chart.value
    return
@app.cell
def _(chart):
    # Writes the chart to "embedding.html", relative to the working directory.
    # NOTE(review): this reruns (and overwrites the file) whenever `chart`
    # changes. Also confirm that mo.ui.altair_chart exposes Altair's .save().
    chart.save("embedding.html")
    return
@app.cell
def _(dataframe_from_embedding, embedding):
    # Random subsample (default 20000 rows) of the embedding as a DataFrame
    # for the scatter plot.
    df = dataframe_from_embedding(embedding)
    return (df,)
@app.cell
def _(pymde):
    # Embedding settings. The plot uses columns 0 and 1 as x/y, hence 2
    # dimensions; the Standardized constraint is passed to compute_embedding.
    embedding_dimension = 2
    constraint = pymde.Standardized()
    return constraint, embedding_dimension
@app.cell
def _(args, compute_embedding, constraint, embedding_dimension, mo):
    # Compute the embedding inside mo.persistent_cache("embedding"), so
    # repeated runs with the same inputs can reuse a stored result.
    # NOTE(review): exact cache-key semantics are marimo's — confirm.
    # args.quadratic comes from the argparse cell. The distortions
    # (second return value) are discarded.
    with mo.persistent_cache("embedding"):
        embedding, _ = compute_embedding(
            embedding_dimension, constraint, args.quadratic
        )
    return (embedding,)
@app.cell
def _(quadratic):
    import argparse

    # CLI for script runs (e.g. `python embedding_mnist.py -q`). The UI switch
    # supplies the default, so the switch decides whenever -q is absent.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-q", "--quadratic", action="store_true", default=quadratic.value
    )
    try:
        # parse_known_args returns unrecognized argv entries instead of
        # exiting on them, so foreign arguments don't break the notebook.
        args, _unknown = parser.parse_known_args()
    except SystemExit:
        # argparse exits on -h/--help or malformed input. Fall back to the UI
        # switch's value rather than a hard-coded False, which ignored the
        # switch. Only SystemExit is caught so that KeyboardInterrupt and real
        # errors still propagate.
        args = argparse.Namespace(quadratic=quadratic.value)
    return (args,)
# Running the file directly (python embedding_mnist.py) executes the notebook.
if __name__ == "__main__":
    app.run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment