Skip to content

Instantly share code, notes, and snippets.

@akshayka
Created November 20, 2025 16:09
Show Gist options
  • Select an option

  • Save akshayka/e0937cf59d0cdc4f61a0811b0efaf600 to your computer and use it in GitHub Desktop.

Select an option

Save akshayka/e0937cf59d0cdc4f61a0811b0efaf600 to your computer and use it in GitHub Desktop.
embedding_mnist.py
# /// script
# requires-python = ">=3.13"
# dependencies = [
# "altair",
# "marimo>=0.17.0",
# "matplotlib",
# "numpy",
# "pandas",
# "polars",
# "pyarrow",
# "pymde",
# "pyzmq",
# "torch",
# ]
# ///
import marimo
# Version string stamped into this notebook file
# (presumably the marimo release that generated it — confirm).
__generated_with = "0.17.8"
# Application object; every cell function below is decorated with `@app.cell`.
app = marimo.App(width="columns")
@app.cell
def _():
    # Setup cell: imports the libraries used by the notebook and loads the
    # MNIST dataset. The names in the return tuple are the ones other cells
    # receive as parameters.
    import marimo as mo
    import altair as alt
    import matplotlib.pyplot as plt
    # `np`, `pl` and `pyarrow` are imported here but are not part of the
    # return tuple.
    # NOTE(review): presumably polars/pyarrow are needed by the `mo.sql`
    # cells — confirm before removing any of these imports.
    import numpy as np
    import pandas as pd
    import pymde
    import torch
    import polars as pl
    import pyarrow
    # Dataset object from pymde; later cells read `mnist.data` and
    # `mnist.attributes["digits"]`.
    mnist = pymde.datasets.MNIST()
    return alt, mnist, mo, pd, plt, pymde, torch
@app.cell
def _(mnist, mo, pymde, torch):
    def compute_embedding(embedding_dim, constraint=None, quadratic=False):
        """Embed ``mnist.data`` with ``pymde.preserve_neighbors``.

        Args:
            embedding_dim: Value forwarded as ``embedding_dim``.
            constraint: pymde constraint to use; ``pymde.Standardized()`` is
                substituted when this is ``None``.
            quadratic: When truthy, use the ``Quadratic`` attractive penalty
                with no repulsive penalty; otherwise use ``Log1p`` (attractive)
                and ``Log`` (repulsive).

        Returns:
            A pair ``(X, distortions)``: the result of ``mde.embed`` and of
            ``mde.distortions()``.
        """
        # Temporary banner shown while the computation runs; it is cleared
        # again right after the embedding has been computed.
        notice = mo.md("Your embedding is being computed ... hang tight!")
        mo.output.append(notice.callout(kind="warn"))

        if constraint is None:
            constraint = pymde.Standardized()

        if quadratic:
            attractive, repulsive = pymde.penalties.Quadratic, None
        else:
            attractive, repulsive = pymde.penalties.Log1p, pymde.penalties.Log

        # Use the GPU whenever torch reports one.
        device = "cuda" if torch.cuda.is_available() else "cpu"

        problem = pymde.preserve_neighbors(
            mnist.data,
            attractive_penalty=attractive,
            repulsive_penalty=repulsive,
            embedding_dim=embedding_dim,
            constraint=constraint,
            device=device,
            verbose=True,
        )
        points = problem.embed(verbose=True)
        mo.output.clear()
        return points, problem.distortions()

    return (compute_embedding,)
@app.cell
def _(mnist, pd, torch):
    def dataframe_from_embedding(embedding, samples=20000):
        """Build a plotting dataframe from a random sample of the embedding.

        Args:
            embedding: Torch tensor with one row per MNIST item; only its
                first two columns are used.
            samples: Maximum number of rows to sample (without replacement).

        Returns:
            A pandas DataFrame with columns ``index`` (row number in the
            dataset), ``x``, ``y`` (first two embedding coordinates) and
            ``digit`` (label from ``mnist.attributes["digits"]``).
        """
        # Convert the sampled row numbers to NumPy once, so the "index"
        # column holds plain integers rather than torch objects.
        indices = torch.randperm(mnist.data.shape[0])[:samples].numpy()

        # The embedding is computed on CUDA when a GPU is available, and
        # Tensor.numpy() only works for CPU tensors that do not require grad,
        # so detach and move to the CPU first (no-ops for a plain CPU tensor).
        embedding_np = embedding.detach().cpu().numpy()[indices]
        digits = mnist.attributes["digits"].numpy()[indices]

        df = pd.DataFrame(
            {
                "index": indices,
                "x": embedding_np[:, 0],
                "y": embedding_np[:, 1],
                "digit": digits,
            }
        )
        return df

    return (dataframe_from_embedding,)
@app.cell
def _(alt):
    def scatter(df, size=4):
        """Return a 500x500 Altair scatter plot of ``df``.

        ``x`` and ``y`` are plotted as quantitative fields on axes fixed to
        (-2.5, 2.5); points are coloured by the nominal ``digit`` field and
        drawn as circles of the given ``size``.
        """
        axis_range = (-2.5, 2.5)
        points = alt.Chart(df).mark_circle(size=size)
        encoded = points.encode(
            x=alt.X("x:Q").scale(domain=axis_range),
            y=alt.Y("y:Q").scale(domain=axis_range),
            color=alt.Color("digit:N"),
        )
        return encoded.properties(width=500, height=500)

    return (scatter,)
@app.cell
def _(mnist, plt):
    def show_images(indices, max_images=10):
        """Render up to ``max_images`` MNIST images side by side.

        Args:
            indices: Sequence of row numbers into ``mnist.data``.
            max_images: Only the first ``max_images`` entries are drawn.

        Returns:
            A matplotlib figure (12.5 x 1.5 inches) with one tick-less
            grayscale axes per image; a blank figure of the same size when
            ``indices`` is empty.
        """
        indices = indices[:max_images]
        figsize = (12.5, 1.5)

        if len(indices) == 0:
            # plt.subplots rejects a grid with zero columns, so hand back an
            # empty figure instead of raising.
            return plt.figure(figsize=figsize)

        images = mnist.data.reshape((-1, 28, 28))[indices]
        # squeeze=False always returns a 2-D array of axes, so a single image
        # is handled by the same loop as several.
        fig, axes = plt.subplots(
            1, len(indices), squeeze=False, figsize=figsize
        )
        for im, ax in zip(images, axes.flat):
            ax.imshow(im, cmap="gray")
            ax.set_yticks([])
            ax.set_xticks([])
        fig.tight_layout()
        return fig

    return (show_images,)
@app.cell
def _(mo):
    # Notebook title, written as a Markdown level-1 heading.
    mo.md(r"""
# Embedding MNIST
""")
    return
@app.cell
def _(mo):
    # Introductory Markdown text describing what the scatter plot shows.
    mo.md(r"""
Here's an **embedding of MNIST**: each point represents a digit,
with similar digits close to each other.
""")
    return
@app.cell
def _(mo):
    # Switch (off by default) labelled "Spectral embedding?". Its value is
    # used as the default of the `--quadratic` flag in the argument-parsing
    # cell.
    quadratic = mo.ui.switch(value=False, label="Spectral embedding?")
    # Bare expression: the switch itself is the last expression of the cell.
    quadratic
    return (quadratic,)
@app.cell
def _(df, mo, scatter):
    # Wrap the Altair scatter plot of `df` in a marimo UI element; other
    # cells read the current selection through `chart.value`.
    chart = mo.ui.altair_chart(scatter(df))
    chart
    return (chart,)
@app.cell
def _(chart, mo):
    # Table over the chart's current value. It is embedded in the preview
    # cell, which also reads `table.value`.
    table = mo.ui.table(chart.value)
    return (table,)
@app.cell
def _(chart, mo, show_images, table):
    # Halt this cell via mo.stop() while nothing is selected in the chart.
    mo.stop(not len(chart.value))

    # Preview at most 10 images (the show_images default): rows chosen in the
    # table take priority, otherwise fall back to the chart selection.
    if len(table.value):
        _picked = table.value
    else:
        _picked = chart.value
    selected_images = show_images(list(_picked["index"]))

    mo.md(
        f"""
**Here's a preview of the images you've selected**:
{mo.as_html(selected_images)}
Here's all the data you've selected.
{table}
"""
    )
    return
@app.cell
def _(chart):
    # Publish the chart's current selection under the name that the SQL cells
    # query (`FROM selection`). It is returned so that the name is exported
    # from this cell instead of being computed and dropped.
    selection = chart.value
    return (selection,)
@app.cell
def _(mo):
    # Most frequent digit in the current selection: at most one row with the
    # columns (digit, frequency).
    # The result is bound to the public name `most_common_digit` and returned
    # because the following cell takes `most_common_digit` as a parameter; a
    # `_`-prefixed name is private to the cell that defines it.
    most_common_digit = mo.sql(
        f"""
SELECT digit, COUNT(*) AS frequency
FROM selection
GROUP BY digit
ORDER BY frequency DESC
LIMIT 1;
"""
    )
    return (most_common_digit,)
@app.cell
def _(mo, most_common_digit):
    # If the frequency query produced no rows (nothing selected), indexing
    # [0, 0] below would raise; mo.stop() halts the cell instead.
    mo.stop(not len(most_common_digit))

    # Selected rows whose digit differs from the most frequent one.
    # [0, 0] reads the first row / first column of the query result, i.e. the
    # `digit` value.
    _df = mo.sql(
        f"""
SELECT * FROM selection where digit != {most_common_digit[0, 0]};
"""
    )
    return
@app.cell
def _(chart):
    # Save the chart to "embedding.html" in the current working directory.
    # NOTE(review): `chart` is the `mo.ui.altair_chart` wrapper, not the raw
    # Altair chart returned by `scatter` — confirm the wrapper exposes `.save`.
    chart.save("embedding.html")
    return
@app.cell
def _(dataframe_from_embedding, embedding):
    # Plot-ready dataframe built from the embedding, using the default sample
    # size of `dataframe_from_embedding`.
    df = dataframe_from_embedding(embedding)
    return (df,)
@app.cell
def _(pymde):
    # Settings passed to `compute_embedding`: a 2-dimensional embedding with
    # pymde's `Standardized` constraint.
    embedding_dimension = 2
    constraint = pymde.Standardized()
    return constraint, embedding_dimension
@app.cell
def _(args, compute_embedding, constraint, embedding_dimension, mo):
    # Compute the embedding inside marimo's persistent-cache block named
    # "embedding".
    # NOTE(review): presumably this lets later runs reuse a stored result
    # instead of recomputing — confirm against the marimo documentation.
    with mo.persistent_cache("embedding"):
        # The second return value (the distortions) is discarded.
        embedding, _ = compute_embedding(
            embedding_dimension, constraint, args.quadratic
        )
    return (embedding,)
@app.cell
def _(quadratic):
    import argparse

    # `-q/--quadratic` forces the quadratic embedding from the command line;
    # without the flag, the UI switch's current value is used.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-q", "--quadratic", action="store_true", default=quadratic.value
    )
    try:
        # parse_known_args ignores argv entries this parser does not define
        # (e.g. arguments meant for whatever launched the notebook) instead
        # of aborting on them.
        args, _unknown = parser.parse_known_args()
    except SystemExit:
        # argparse reports errors (and handles -h) by raising SystemExit.
        # Fall back to the switch value so the UI toggle still takes effect;
        # other exceptions such as KeyboardInterrupt are no longer swallowed.
        args = argparse.Namespace(quadratic=quadratic.value)
    return (args,)
# Run the application when this file is executed directly as a script.
if __name__ == "__main__":
    app.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment