@hasithv
Created December 11, 2025 00:36
import math
import numpy as np
import torch
from transformers import AutoModel
from PIL import Image
from tqdm import tqdm
# ----------------- CONFIG -----------------
MODEL_NAME = "gpt2"        # change to another HF AutoModel LLM if you want
NUM_FRAMES = 48            # number of frames around the circle
HEIGHT = 128 * 2           # pixel grid height
WIDTH = 256 * 2            # pixel grid width
RECT_HALF_WIDTH = 5.0      # rectangle spans 10 x 20 (1:2 aspect) in plane coords
RECT_HALF_HEIGHT = 10.0
GIF_PATH = "llm_art_pca.gif"
FRAME_DURATION_MS = 80
SEED = 0
SAMPLES_PER_FRAME = 1024   # how many residuals to sample per frame for PCA
# ------------------------------------------

def random_unit_vector(dim, device):
    v = torch.randn(dim, device=device)
    v = v / v.norm()
    return v

def main():
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    # use a GPU if one is available, otherwise fall back to CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AutoModel.from_pretrained(MODEL_NAME).to(device)
    model.eval()
    d_model = model.config.hidden_size
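    # for the default "gpt2" checkpoint, d_model == 768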
    # 1–2: two random unit vectors, then orthonormalize to span a plane
    u = random_unit_vector(d_model, device)
    v_raw = random_unit_vector(d_model, device)
    v = v_raw - torch.dot(u, v_raw) * u
    v = v / v.norm()
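    # Removing the u-component of v_raw and renormalizing is one Gram-Schmidt
    # step; an optional sanity check that the pair really is orthonormal, so
    # the (x, y) grid below gives isometric coordinates on the plane:
    assert abs(torch.dot(u, v).item()) < 1e-4
    assert abs(v.norm().item() - 1.0) < 1e-4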
    # 3: precompute grid in the (u, v) plane for the rectangle
    xs = torch.linspace(-RECT_HALF_WIDTH, RECT_HALF_WIDTH, WIDTH, device=device)
    ys = torch.linspace(-RECT_HALF_HEIGHT, RECT_HALF_HEIGHT, HEIGHT, device=device)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")  # (H, W)
    plane_offsets = grid_x[..., None] * u + grid_y[..., None] * v  # (H, W, d_model)
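    # Note: plane_offsets holds H * W * d_model floats (256 * 512 * 768, about
    # 400 MB in fp32 for GPT-2), and each forward pass below runs a batch of
    # H * W = 131072 single-token sequences; shrink HEIGHT/WIDTH if memory is tight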
    frames = []
    with torch.no_grad():
        # ------------ PASS 1: collect samples for PCA ------------
        sample_list = []
        for frame_idx in tqdm(range(NUM_FRAMES), desc="Pass 1/2 (collect PCA samples)"):
            theta = 2.0 * math.pi * frame_idx / NUM_FRAMES
            # point on the unit circle in the (u, v) plane
            center_vec = math.cos(theta) * u + math.sin(theta) * v  # (d_model,)
            # center of rectangle + all grid offsets
            points = center_vec + plane_offsets  # (H, W, d_model)
            points_flat = points.reshape(-1, d_model)  # (H*W, d_model)
            inputs_embeds = points_flat.unsqueeze(1)  # (B, 1, d_model)
            outputs = model(inputs_embeds=inputs_embeds, output_hidden_states=True)
            hidden_states = outputs.hidden_states  # tuple, len = num_layers + 1
            mid_layer_idx = len(hidden_states) // 2
            residual_vectors = hidden_states[mid_layer_idx][:, 0, :]  # (B, d_model)
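            # hidden_states[0] is the embedding output and hidden_states[i] the
            # output of block i, so mid_layer_idx reads out the middle of the
            # residual stream (after block 6 of 12 for GPT-2)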
            # sample a subset of residuals for PCA
            b = residual_vectors.size(0)
            k = min(SAMPLES_PER_FRAME, b)
            idx = torch.randint(0, b, (k,), device=device)
            sample = residual_vectors[idx].cpu()
            sample_list.append(sample)
        samples = torch.cat(sample_list, dim=0)  # (N_s, d_model)
        print(f"Collected {samples.size(0)} samples for PCA")
        # ------------ PCA on sampled residuals ------------
        # center the data
        mean_vec = samples.mean(dim=0)  # (d_model,)
        X_centered = samples - mean_vec  # (N_s, d_model)
        # move to device for PCA
        Xc_device = X_centered.to(device)
        # low-rank PCA: we only need the first 3 principal components.
        # Xc_device has shape (N_s, d_model); with q=3, V has shape (d_model, 3)
        U, S, V = torch.pca_lowrank(Xc_device, q=3, center=False)
        pcs = V[:, :3]  # (d_model, 3)
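        # Optional: S holds the (approximate) top-3 singular values of
        # X_centered, so S**2 / ||X_centered||_F**2 is the fraction of residual
        # variance each colour channel (PC) will carry
        var_frac = (S ** 2) / Xc_device.pow(2).sum()
        print(f"Variance fraction per PC: {var_frac.cpu().numpy()}")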
        # use sample projections to set global min/max for each PC
        sample_proj = Xc_device @ pcs  # (N_s, 3)
        sample_proj_cpu = sample_proj.cpu()
        mins = sample_proj_cpu.min(dim=0).values  # (3,)
        maxs = sample_proj_cpu.max(dim=0).values  # (3,)
        # push PCA params to device for pass 2
        mean_vec = mean_vec.to(device)
        pcs = pcs.to(device)
        mins = mins.to(device)
        maxs = maxs.to(device)
        # ------------ PASS 2: generate frames using PCA projections ------------
        for frame_idx in tqdm(range(NUM_FRAMES), desc="Pass 2/2 (render frames)"):
            theta = 2.0 * math.pi * frame_idx / NUM_FRAMES
            center_vec = math.cos(theta) * u + math.sin(theta) * v  # (d_model,)
            points = center_vec + plane_offsets  # (H, W, d_model)
            points_flat = points.reshape(-1, d_model)  # (H*W, d_model)
            inputs_embeds = points_flat.unsqueeze(1)  # (B, 1, d_model)
            outputs = model(inputs_embeds=inputs_embeds, output_hidden_states=True)
            hidden_states = outputs.hidden_states
            mid_layer_idx = len(hidden_states) // 2
            residual_vectors = hidden_states[mid_layer_idx][:, 0, :]  # (B, d_model)
            # project onto first 3 PCs
            centered = residual_vectors - mean_vec  # (B, d_model)
            proj = centered @ pcs  # (B, 3)
            # normalize each PC to [0, 1] using global min/max, then map to [0, 255]
            denom = maxs - mins
            denom[denom == 0] = 1e-8
            proj_norm = (proj - mins) / denom
            proj_norm = proj_norm.clamp(0.0, 1.0)
            rgb = (proj_norm * 255.0).clamp(0.0, 255.0)
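            # because mins/maxs come from the pass-1 samples, the colour mapping
            # is fixed across frames; per-frame normalization would make the
            # palette flicker as the rectangle moves around the circle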
            rgb_np = rgb.reshape(HEIGHT, WIDTH, 3).detach().cpu().numpy().astype(np.uint8)
            img = Image.fromarray(rgb_np, mode="RGB")
            frames.append(img)

    # Looping GIF: the GIF itself loops, and the path is circular, so start == end visually
    frames[0].save(
        GIF_PATH,
        save_all=True,
        append_images=frames[1:],
        loop=0,
        duration=FRAME_DURATION_MS,
    )
    print(f"Saved GIF to {GIF_PATH}")


if __name__ == "__main__":
    main()