Created
February 6, 2026 10:36
-
-
Save rolux/f5b9ffd05377a8d8e8061b66ddf0bcb1 to your computer and use it in GitHub Desktop.
How to gradually overwrite an image model with its own output
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # /// script | |
| # requires-python = ">=3.12" | |
| # dependencies = [ | |
| # "mflux", | |
| # "mlx", | |
| # "numpy", | |
| # "pillow", | |
| # "protobuf", | |
| # "sentencepiece" | |
| # ] | |
| # /// | |
| """ | |
| download the model: | |
| mflux-save --model dev --path models/dev | |
| invoke with: | |
| uv run transformer.py seed | |
| """ | |
# patch weight loader
# NOTE(review): this edits a file inside the installed mflux package so that a
# quantization level stored as the *string* "None" is treated like a real None.
# Presumably that string comes from metadata written by `mflux-save` — confirm
# against the mflux version in use.
import importlib.util


def _patch_weight_loader():
    """Apply the textual "None"-quantization fix to mflux's weight_loader source.

    The replacement is idempotent: once applied, the search text no longer
    occurs. The file is therefore only rewritten when the replacement actually
    changes it. Repeated runs do not truncate and rewrite a file in
    site-packages, and an already-patched read-only install keeps working.

    Raises:
        ModuleNotFoundError: if the weight_loader module cannot be located.
    """
    module_name = "mflux.models.common.weights.loading.weight_loader"
    # find_spec on a dotted name imports the parent packages first.
    # NOTE(review): if one of those parents already imports weight_loader, the
    # on-disk patch only takes effect on the next run — confirm.
    spec = importlib.util.find_spec(module_name)
    if spec is None or spec.origin is None:
        raise ModuleNotFoundError(f"cannot locate {module_name}", name=module_name)
    with open(spec.origin, "r", encoding="utf-8") as f:
        source = f.read()
    patched = source.replace(
        'if quantization_level_str is not None',
        'if quantization_level_str not in (None, "None")'
    )
    if patched != source:
        with open(spec.origin, "w", encoding="utf-8") as f:
            f.write(patched)


_patch_weight_loader()
| import json | |
| import os | |
| import sys | |
| import mlx.core as mx | |
| import numpy as np | |
| from PIL import Image | |
| from mflux.models.common.config.model_config import ModelConfig | |
| from mflux.models.flux.variants.txt2img.flux import Flux1 | |
### Model ###

# Both are filled in by load_model(); OUTPUT_DIR is the root for all renders.
MODEL_NAME = None
OUTPUT_DIR = None


def load_model(model_name="dev", model_path="models/dev"):
    """Build an unquantized mflux ``Flux1`` model from locally saved weights.

    Side effects:
        Sets the module globals ``MODEL_NAME`` to ``model_name`` and
        ``OUTPUT_DIR`` to ``"outputs/<model_name>"`` before the model is built.

    Args:
        model_name: name handed to ``ModelConfig.from_name``.
        model_path: directory holding the saved model weights.

    Returns:
        The constructed ``Flux1`` instance.
    """
    global MODEL_NAME, OUTPUT_DIR
    MODEL_NAME, OUTPUT_DIR = model_name, f"outputs/{model_name}"
    config = ModelConfig.from_name(model_name=model_name)
    return Flux1(model_config=config, model_path=model_path, quantize=None)
def render_image(
    filename,
    prompt,
    seed=42,
    width=1024,
    height=1024,
    num_inference_steps=20,
    guidance=4.0
):
    """Generate one image with the global ``flux`` model and save it, unless cached.

    Nothing happens if ``filename`` already exists. This lets an interrupted
    run resume without regenerating earlier frames.

    Args:
        filename: output path; its parent directory is created if missing.
        prompt: text prompt for the model.
        seed, width, height, num_inference_steps, guidance: forwarded
            unchanged to ``flux.generate_image``.
    """
    if not os.path.exists(filename):
        print(f"rendering {filename}")
        settings = {
            "prompt": prompt,
            "seed": seed,
            "width": width,
            "height": height,
            "num_inference_steps": num_inference_steps,
            "guidance": guidance,
        }
        generated = flux.generate_image(**settings)
        target_dir = os.path.dirname(filename)
        os.makedirs(target_dir, exist_ok=True)
        generated.save(filename)
### Utils ###


def get_frames(dirname):
    """Write an ffmpeg concat-demuxer list of the ``.png`` files in ``dirname``.

    The list goes to ``<dirname>.txt``, a sibling of the directory. Each entry
    is prefixed with the directory's basename so that paths are relative to
    the list file's own location. Entries are in sorted filename order,
    separated by newlines with no trailing newline.

    Returns:
        The path of the written list file.
    """
    list_path = f"{dirname}.txt"
    print(f"writing {list_path}")
    prefix = os.path.basename(dirname)
    with open(list_path, "w") as handle:
        entries = []
        for entry_name in sorted(os.listdir(dirname)):
            if entry_name.endswith(".png"):
                entries.append(f"file '{prefix}/{entry_name}'")
        handle.write("\n".join(entries))
    return list_path
def get_layer_list(filename, model_path="models/dev"):
    """Dump ``key, shape, dtype`` for every transformer weight, unless cached.

    The keys come from ``weight_map`` in
    ``<model_path>/transformer/model.safetensors.index.json``. Each key is
    resolved against the loaded ``flux.transformer`` through
    ``get_parent_and_name``, so the reported shape and dtype are those of the
    in-memory tensor. Does nothing if ``filename`` already exists.
    """
    if os.path.exists(filename):
        return
    print(f"writing {filename}")
    index_path = f"{model_path}/transformer/model.safetensors.index.json"
    with open(index_path, "r") as handle:
        index = json.load(handle)

    def describe(key):
        owner, attr = get_parent_and_name(key)
        tensor = owner[attr]
        return f"{key}, {tensor.shape}, {tensor.dtype}"

    rows = [describe(key) for key in index["weight_map"]]
    with open(filename, "w") as handle:
        handle.write("\n".join(rows))
def get_parent_and_name(string):
    """Resolve a dotted weight key against the global ``flux.transformer``.

    Every segment except the last is followed from ``flux.transformer``:
    all-digit segments index into a sequence (e.g. ``transformer_blocks.0``),
    and other segments are read with ``getattr``. The last segment is returned
    unresolved so callers can both read ``parent[name]`` and assign to it.

    Returns:
        ``(parent, name)``
    """
    node = flux.transformer
    *path, leaf = string.split(".")
    for part in path:
        if part.isdigit():
            node = node[int(part)]
        else:
            node = getattr(node, part)
    return node, leaf
def render_grid(filename, image_filenames, size):
    """Tile images left to right into a single row of ``size`` x ``size`` cells.

    Each input is resized to a square cell with LANCZOS resampling, so the
    aspect ratio is not preserved. The result is an RGB image
    ``len(image_filenames) * size`` pixels wide. Does nothing if ``filename``
    already exists.
    """
    if os.path.exists(filename):
        return
    print(f"rendering {filename}")
    canvas = Image.new("RGB", (size * len(image_filenames), size))
    x_offset = 0
    for path in image_filenames:
        tile = Image.open(path).resize((size, size), Image.LANCZOS)
        canvas.paste(tile, (x_offset, 0))
        x_offset += size
    target_dir = os.path.dirname(filename)
    os.makedirs(target_dir, exist_ok=True)
    canvas.save(filename)
### Transformer ###


def render_transformer_layer(filename, layer, mode="RGB"):
    """Save a min-max-normalised picture of one transformer weight, unless cached.

    The weight is converted to float32, linearly stretched so its minimum maps
    to 0 and its maximum to 255, and stored as uint8.

    Args:
        filename: output path; skipped entirely if it already exists.
        layer: dotted weight key, resolved with ``get_parent_and_name``.
        mode: ``"RGB"`` packs each row's columns into consecutive RGB triples.
            The column count must be divisible by 3; e.g. a 3072x3072 weight
            becomes a 1024-wide, 3072-tall RGB image. The result is then resized
            to 1024x1024. ``"L"`` writes one grayscale pixel per weight at full
            resolution.

    Raises:
        ValueError: if ``mode`` is not ``"RGB"`` or ``"L"``.
    """
    if os.path.exists(filename):
        return
    if mode not in ("RGB", "L"):
        raise ValueError(f"unsupported mode {mode!r}; expected 'RGB' or 'L'")
    print(f"rendering {filename}")
    parent, name = get_parent_and_name(layer)
    w = np.array(parent[name].astype(mx.float32))
    w_min, w_max = float(w.min()), float(w.max())
    span = w_max - w_min
    if span > 0:
        w = (w - w_min) * (255.0 / span)
    else:
        # A constant tensor has no range to stretch: render it mid-grey
        # instead of raising ZeroDivisionError.
        w = np.full_like(w, 127.5)
    w = np.clip(w, 0, 255).astype(np.uint8)
    if mode == "RGB":
        # (rows, cols) -> (rows, cols // 3, 3), derived from the actual shape.
        w = w.reshape(w.shape[0], -1, 3)
        # uint8 arrays of shape (H, W, 3) are read as RGB by fromarray, so the
        # deprecated `mode=` argument is not needed.
        image = Image.fromarray(w).resize((1024, 1024), Image.LANCZOS)
    else:
        # uint8 arrays of shape (H, W) are read as 8-bit grayscale ("L").
        image = Image.fromarray(w)
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    image.save(filename)
def patch_transformer_with_image(layer, filename, strength, mode="RGB"):
    """Add an image to a transformer weight in place, scaled by ``strength``.

    The image at ``filename`` is resized to match the weight's 2-D shape
    ``(rows, cols)``. This uses the same layout ``render_transformer_layer``
    uses for display. It is then scaled to ``[-0.5, 0.5)`` in float16,
    multiplied by ``strength``, and added to ``parent[name]``. The result is
    cast back to the weight's original dtype.

    Args:
        layer: dotted weight key, resolved with ``get_parent_and_name``.
        filename: image file to blend into the weight.
        strength: multiplier applied to the centred image.
        mode: ``"RGB"`` reads the image as ``(rows, cols // 3)`` RGB pixels,
            so the column count must be divisible by 3. ``"L"`` reads it as
            ``(rows, cols)`` grayscale pixels.

    Returns:
        The weight array as it was before this update.

    Raises:
        ValueError: if ``mode`` is not ``"RGB"`` or ``"L"``, or if the weight
            is not 2-D.
    """
    if mode not in ("RGB", "L"):
        raise ValueError(f"unsupported mode {mode!r}; expected 'RGB' or 'L'")
    parent, name = get_parent_and_name(layer)
    original = mx.array(parent[name])
    rows, cols = original.shape
    with Image.open(filename) as source:
        if mode == "RGB":
            # convert() guards against RGBA/L PNGs, whose channel count would
            # otherwise break the reshape below.
            pixels = source.resize((cols // 3, rows), Image.LANCZOS).convert("RGB")
        else:
            pixels = source.resize((cols, rows), Image.LANCZOS).convert("L")
    delta = np.array(pixels, dtype=np.float16) / 255 - 0.5
    delta = delta.reshape(rows, cols)
    # Keep the weight's own dtype. Adding a float16 delta to a weight of a
    # different float dtype (e.g. bfloat16) can otherwise promote the stored
    # weight to a wider dtype.
    parent[name] = (parent[name] + strength * mx.array(delta)).astype(original.dtype)
    return original
def render_transformer(
    dirname="transformer",
    prompt="transformer",
    seed=42,
    layer="transformer_blocks.0.attn.to_out.0.weight",
    n=60,
    strength=0.1,
    mode="RGB",
    size=2160
):
    """Feed the model's own output back into one of its weights, ``n`` times.

    Outputs go to ``<OUTPUT_DIR>/<dirname>/<prompt>,<seed>,<layer>,<strength>``.
    Each iteration ``i``:
      1. renders the current weights of ``layer`` to ``model/<i>.png``;
      2. generates ``image/<i>.png`` from ``prompt`` and ``seed``;
      3. places the two side by side in ``video/<i>.png`` (``size``-pixel cells);
      4. adds ``image/<i>.png`` back into ``layer`` with ``strength``.
    Existing PNGs are not regenerated, but step 4 runs on every iteration. A
    resumed run therefore re-applies the updates from the cached images.
    Finally, the ``video`` frames are encoded with ffmpeg at one input frame
    per second into ``<run dir>.mp4``.
    """
    import subprocess

    dirname = f"{OUTPUT_DIR}/{dirname}/{prompt},{seed},{layer},{strength:.06f}"
    for i in range(n):
        filename_model = f"{dirname}/model/{i:08d}.png"
        render_transformer_layer(filename_model, layer, mode=mode)
        filename_image = f"{dirname}/image/{i:08d}.png"
        render_image(filename_image, prompt, seed)
        filename_video = f"{dirname}/video/{i:08d}.png"
        render_grid(filename_video, [filename_model, filename_image], size)
        patch_transformer_with_image(layer, filename_image, strength, mode=mode)
    filename_frames = get_frames(f"{dirname}/video")
    # Pass argv as a list: the paths embed the prompt and layer text. Quotes,
    # `$` or backticks there would break or inject into a shell command string.
    command = [
        "ffmpeg", "-y", "-r", "1", "-f", "concat", "-safe", "0",
        "-i", filename_frames,
        "-c:v", "libx264", "-pix_fmt", "yuv420p", "-crf", "18",
        f"{dirname}.mp4",
    ]
    try:
        result = subprocess.run(command, check=False)
    except FileNotFoundError:
        # Video encoding stays best-effort, as with the former os.system call.
        print("ffmpeg not found; skipping video encoding")
        return
    if result.returncode != 0:
        print(f"ffmpeg exited with status {result.returncode}")
if __name__ == "__main__":
    # argv values are strings. Convert the seed so the model receives an int,
    # and fall back to the same default seed render_transformer() uses.
    seed = int(sys.argv[1]) if len(sys.argv) > 1 else 42
    # `flux` must be a module-level global: render_image() and
    # get_parent_and_name() read it directly.
    flux = load_model()
    get_layer_list(f"layers_{MODEL_NAME}.txt")
    render_transformer(seed=seed)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment