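"""Animate a 2-D slice of an LLM's embedding space as a looping GIF.

Picks a random plane in the model's input embedding space, feeds a rectangular
grid of points on that plane through the model as length-1 "sequences", and
colors each pixel by projecting its middle-layer residual vector onto the top
3 principal components (fit once on samples gathered across all frames). The
rectangle's center travels around the unit circle in the same plane, one frame
per step, and the frames are written to GIF_PATH.
"""
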
import math

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoModel

# ----------------- CONFIG -----------------
MODEL_NAME = "gpt2"  # change to another HF AutoModel LLM if you want
NUM_FRAMES = 48  # number of frames around the circle
HEIGHT = 128 * 2  # pixel grid height
WIDTH = 256 * 2  # pixel grid width
RECT_HALF_WIDTH = 5.0  # half-width of the rectangle in plane coords
RECT_HALF_HEIGHT = 10.0  # half-height, so the full rectangle is 10 x 20
GIF_PATH = "llm_art_pca.gif"
FRAME_DURATION_MS = 80
SEED = 0
SAMPLES_PER_FRAME = 1024  # how many residuals to sample per frame for PCA
# ------------------------------------------


def random_unit_vector(dim, device):
    v = torch.randn(dim, device=device)
    return v / v.norm()
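

# Optional sketch, not called below: all H * W grid points go through the model
# in a single forward pass, which can exhaust GPU memory. If it does, a chunked
# pass like this keeps the peak bounded; `chunk_size` is a tuning knob, not part
# of the original script.
def mid_layer_residuals_chunked(model, points_flat, chunk_size=8192):
    parts = []
    for start in range(0, points_flat.size(0), chunk_size):
        chunk = points_flat[start:start + chunk_size].unsqueeze(1)  # (b, 1, d_model)
        hs = model(inputs_embeds=chunk, output_hidden_states=True).hidden_states
        parts.append(hs[len(hs) // 2][:, 0, :])  # mid-layer residuals, (b, d_model)
    return torch.cat(parts, dim=0)  # (B, d_model)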


def main():
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    # use a GPU when available, otherwise fall back to CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = AutoModel.from_pretrained(MODEL_NAME).to(device)
    model.eval()
    d_model = model.config.hidden_size

    # 1-2: two random unit vectors, then orthonormalize to span a plane
    u = random_unit_vector(d_model, device)
    v_raw = random_unit_vector(d_model, device)
    v = v_raw - torch.dot(u, v_raw) * u
    v = v / v.norm()

    # 3: precompute grid in the (u, v) plane for the rectangle
    xs = torch.linspace(-RECT_HALF_WIDTH, RECT_HALF_WIDTH, WIDTH, device=device)
    ys = torch.linspace(-RECT_HALF_HEIGHT, RECT_HALF_HEIGHT, HEIGHT, device=device)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")  # (H, W)
    plane_offsets = grid_x[..., None] * u + grid_y[..., None] * v  # (H, W, d_model)
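    # Each pixel (x, y) now maps to the embedding-space point x * u + y * v,
    # so every frame is a flat 2-D slice of R^{d_model}; the loop below slides
    # the slice's center around the unit circle in that same plane.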

    frames = []

    with torch.no_grad():
        # ------------ PASS 1: collect samples for PCA ------------
        sample_list = []
        for frame_idx in tqdm(range(NUM_FRAMES), desc="Pass 1/2 (collect PCA samples)"):
            theta = 2.0 * math.pi * frame_idx / NUM_FRAMES
            # point on the unit circle in the (u, v) plane
            center_vec = math.cos(theta) * u + math.sin(theta) * v  # (d_model,)
            # center of rectangle + all grid offsets
            points = center_vec + plane_offsets  # (H, W, d_model)
            points_flat = points.reshape(-1, d_model)  # (H*W, d_model)
            inputs_embeds = points_flat.unsqueeze(1)  # (B, 1, d_model)
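
            # each pixel is its own length-1 "sequence"; B = H * W, so this
            # forward pass is large (use mid_layer_residuals_chunked if it OOMs)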
            outputs = model(inputs_embeds=inputs_embeds, output_hidden_states=True)
            hidden_states = outputs.hidden_states  # tuple, len = num_layers + 1
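            # hidden_states[0] is the embedding output, so len // 2 lands
            # roughly on the middle transformer block's residual stream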
            mid_layer_idx = len(hidden_states) // 2
            residual_vectors = hidden_states[mid_layer_idx][:, 0, :]  # (B, d_model)

            # sample a subset of residuals for PCA (randint samples with replacement)
            b = residual_vectors.size(0)
            k = min(SAMPLES_PER_FRAME, b)
            idx = torch.randint(0, b, (k,), device=device)
            sample_list.append(residual_vectors[idx].cpu())

        samples = torch.cat(sample_list, dim=0)  # (N_s, d_model)
        print(f"Collected {samples.size(0)} samples for PCA")

        # ------------ PCA on sampled residuals ------------
        # center the data
        mean_vec = samples.mean(dim=0)  # (d_model,)
        X_centered = samples - mean_vec  # (N_s, d_model)

        # move to device for PCA
        Xc_device = X_centered.to(device)

        # low-rank PCA: we only need the first 3 principal components;
        # with q=3, pca_lowrank returns V with shape (d_model, 3)
        U, S, V = torch.pca_lowrank(Xc_device, q=3, center=False)
        pcs = V[:, :3]  # (d_model, 3)

        # use sample projections to set a global min/max for each PC
        sample_proj = Xc_device @ pcs  # (N_s, 3)
        sample_proj_cpu = sample_proj.cpu()
        mins = sample_proj_cpu.min(dim=0).values  # (3,)
        maxs = sample_proj_cpu.max(dim=0).values  # (3,)

        # push PCA params to device for pass 2
        mean_vec = mean_vec.to(device)
        pcs = pcs.to(device)
        mins = mins.to(device)
        maxs = maxs.to(device)
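
        # the three PC scores become the R, G, B channels: each pixel's color
        # encodes where its mid-layer residual lands along the top three
        # directions of variation gathered in pass 1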

        # ------------ PASS 2: generate frames using PCA projections ------------
        for frame_idx in tqdm(range(NUM_FRAMES), desc="Pass 2/2 (render frames)"):
            theta = 2.0 * math.pi * frame_idx / NUM_FRAMES
            center_vec = math.cos(theta) * u + math.sin(theta) * v  # (d_model,)
            points = center_vec + plane_offsets  # (H, W, d_model)
            points_flat = points.reshape(-1, d_model)  # (H*W, d_model)
            inputs_embeds = points_flat.unsqueeze(1)  # (B, 1, d_model)

            outputs = model(inputs_embeds=inputs_embeds, output_hidden_states=True)
            hidden_states = outputs.hidden_states
            mid_layer_idx = len(hidden_states) // 2
            residual_vectors = hidden_states[mid_layer_idx][:, 0, :]  # (B, d_model)

            # project onto the first 3 PCs
            centered = residual_vectors - mean_vec  # (B, d_model)
            proj = centered @ pcs  # (B, 3)

            # normalize each PC to [0, 1] using the global min/max, then map to [0, 255]
            denom = maxs - mins
            denom[denom == 0] = 1e-8
            proj_norm = ((proj - mins) / denom).clamp(0.0, 1.0)
            rgb = proj_norm * 255.0

            rgb_np = rgb.reshape(HEIGHT, WIDTH, 3).cpu().numpy().astype(np.uint8)
            frames.append(Image.fromarray(rgb_np, mode="RGB"))

    # the GIF loops and the circular path is closed, so start == end visually
    frames[0].save(
        GIF_PATH,
        save_all=True,
        append_images=frames[1:],
        loop=0,
        duration=FRAME_DURATION_MS,
    )
    print(f"Saved GIF to {GIF_PATH}")


if __name__ == "__main__":
    main()