Skip to content

Instantly share code, notes, and snippets.

@rolux
Created February 6, 2026 10:36
Show Gist options
  • Select an option

  • Save rolux/f5b9ffd05377a8d8e8061b66ddf0bcb1 to your computer and use it in GitHub Desktop.

Select an option

Save rolux/f5b9ffd05377a8d8e8061b66ddf0bcb1 to your computer and use it in GitHub Desktop.
How to gradually overwrite an image model with its own output
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "mflux",
# "mlx",
# "numpy",
# "pillow",
# "protobuf",
# "sentencepiece"
# ]
# ///
"""
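Gradually overwrite one transformer weight matrix of an image model with the
images the model itself generates.

Download the model first:
    mflux-save --model dev --path models/dev
Then invoke, passing the seed as the only argument:
    uv run transformer.py <seed>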
"""
# patch weight loader
# This rewrites the installed mflux weight_loader source file in place. It runs
# before the mflux imports below, presumably so that they load the patched code.
import importlib.util

_WEIGHT_LOADER_MODULE = "mflux.models.common.weights.loading.weight_loader"
filename = importlib.util.find_spec(_WEIGHT_LOADER_MODULE).origin
with open(filename, "r", encoding="utf-8") as f:
    source = f.read()
# Make the loader's check reject the string "None" as well as the None object.
_patched_source = source.replace(
    'if quantization_level_str is not None',
    'if quantization_level_str not in (None, "None")'
)
# Once patched, the search text no longer occurs, so later runs change nothing;
# skip the write in that case instead of rewriting the library file every run.
if _patched_source != source:
    with open(filename, "w", encoding="utf-8") as f:
        f.write(_patched_source)
import json
import os
import sys
import mlx.core as mx
import numpy as np
from PIL import Image
from mflux.models.common.config.model_config import ModelConfig
from mflux.models.flux.variants.txt2img.flux import Flux1
### Model ###
# Both globals are filled in by load_model().
MODEL_NAME = None
OUTPUT_DIR = None


def load_model(model_name="dev", model_path="models/dev"):
    """Build the Flux1 model and remember its name and output directory.

    Side effects: sets the module globals MODEL_NAME (to ``model_name``) and
    OUTPUT_DIR (to ``outputs/<model_name>``).

    Returns the Flux1 instance created from ``model_path`` with
    ``quantize=None``.
    """
    global MODEL_NAME, OUTPUT_DIR
    MODEL_NAME, OUTPUT_DIR = model_name, f"outputs/{model_name}"
    config = ModelConfig.from_name(model_name=model_name)
    return Flux1(model_config=config, model_path=model_path, quantize=None)
def render_image(
    filename,
    prompt,
    seed=42,
    width=1024,
    height=1024,
    num_inference_steps=20,
    guidance=4.0
):
    """Generate an image with the global ``flux`` model and save it.

    Nothing happens when ``filename`` already exists, so an interrupted run
    can be restarted without re-rendering finished images.

    All arguments except ``filename`` are forwarded unchanged to
    ``flux.generate_image``. Returns None.
    """
    if os.path.exists(filename):
        return
    print(f"rendering {filename}")
    image = flux.generate_image(
        prompt=prompt,
        seed=seed,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        guidance=guidance
    )
    # A bare file name has an empty dirname, and os.makedirs("") raises;
    # only create a directory when the path actually contains one.
    dirname = os.path.dirname(filename)
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    image.save(filename)
### Utils ###
def get_frames(dirname):
    """Write a list of the PNG files in ``dirname`` for ffmpeg's concat input.

    The list is written to ``<dirname>.txt``. Each line has the form
    ``file '<basename of dirname>/<png name>'``; lines follow sorted file-name
    order and are joined with newlines (no trailing newline).

    Returns the path of the list file.
    """
    filename = f"{dirname}.txt"
    print(f"writing {filename}")
    path = os.path.basename(dirname)
    with open(filename, "w") as list_file:
        entries = []
        for name in sorted(os.listdir(dirname)):
            if name.endswith(".png"):
                entries.append(f"file '{path}/{name}'")
        list_file.write("\n".join(entries))
    return filename
def get_layer_list(filename, model_path="models/dev"):
    """Write a "key, shape, dtype" line for every transformer weight.

    The keys come from the ``weight_map`` entry of
    ``<model_path>/transformer/model.safetensors.index.json``; each key is
    looked up on the loaded model via ``get_parent_and_name``.

    Nothing happens when ``filename`` already exists. Returns None.
    """
    if os.path.exists(filename):
        return
    print(f"writing {filename}")
    index_path = f"{model_path}/transformer/model.safetensors.index.json"
    with open(index_path, "r") as index_file:
        index = json.load(index_file)

    def describe(key):
        # One output line for a single weight key.
        parent, name = get_parent_and_name(key)
        tensor = parent[name]
        return f"{key}, {tensor.shape}, {tensor.dtype}"

    # Build every line before opening the output, so a failed lookup does not
    # leave a partial file behind.
    lines = [describe(key) for key in index["weight_map"]]
    with open(filename, "w") as out:
        out.write("\n".join(lines))
def get_parent_and_name(string):
    """Resolve a dotted path against the global ``flux.transformer``.

    Every component except the last is followed from ``flux.transformer``:
    all-digit components index into the current object, all others are read
    as attributes.

    Returns ``(parent, name)``, where ``parent`` is the object reached and
    ``name`` is the last, unresolved component of ``string``.
    """
    node = flux.transformer
    *path, leaf = string.split(".")
    for part in path:
        if part.isdigit():
            node = node[int(part)]
        else:
            node = getattr(node, part)
    return node, leaf
def render_grid(filename, image_filenames, size):
    """Paste the given images side by side into a single row and save it.

    Each image is resized to ``size`` x ``size`` (LANCZOS) and placed left to
    right in the order given, so the result is ``len(image_filenames) * size``
    wide and ``size`` high.

    Nothing happens when ``filename`` already exists. Returns None.
    """
    if os.path.exists(filename):
        return
    print(f"rendering {filename}")
    image = Image.new("RGB", (len(image_filenames) * size, size))
    for i, image_filename in enumerate(image_filenames):
        # Open as a context manager so each source file is closed again
        # instead of staying open until garbage collection.
        with Image.open(image_filename) as tile:
            image.paste(tile.resize((size, size), Image.LANCZOS), (i * size, 0))
    # A bare file name has an empty dirname, and os.makedirs("") raises.
    dirname = os.path.dirname(filename)
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    image.save(filename)
### Transformer ###
def render_transformer_layer(filename, layer, mode="RGB"):
    """Save one transformer weight tensor as an image.

    The tensor named by the dotted path ``layer`` is converted to float32 and
    rescaled linearly so its minimum maps to 0 and its maximum to 255.

    mode "RGB": the values are reshaped to (3072, 1024, 3) and the resulting
        image is resized to 1024 x 1024 (LANCZOS).
    mode "L": the values are used directly as a grayscale image.

    Nothing happens when ``filename`` already exists. Returns None.

    Raises ValueError for any other ``mode``.
    """
    if os.path.exists(filename):
        return
    # Reject unknown modes before doing any work; otherwise no image would be
    # created and the save below would fail on an unbound variable.
    if mode not in ("RGB", "L"):
        raise ValueError(f"unsupported mode: {mode!r} (expected 'RGB' or 'L')")
    print(f"rendering {filename}")
    parent, name = get_parent_and_name(layer)
    w = np.array(parent[name].astype(mx.float32))
    w_min, w_max = float(w.min()), float(w.max())
    # When every value is equal the range is zero and dividing by it would
    # raise ZeroDivisionError; render such a tensor as all zeros instead.
    value_range = w_max - w_min
    scale = 255.0 / value_range if value_range > 0 else 0.0
    w = (w - w_min) * scale
    w = np.clip(w, 0, 255).astype(np.uint8)
    if mode == "RGB":
        # Three consecutive values of a row become one RGB pixel.
        w = w.reshape(3072, 1024, 3)
        image = Image.fromarray(w, mode=mode).resize((1024, 1024), Image.LANCZOS)
    else:
        image = Image.fromarray(w, mode=mode)
    # A bare file name has an empty dirname, and os.makedirs("") raises.
    dirname = os.path.dirname(filename)
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    image.save(filename)
def patch_transformer_with_image(layer, filename, strength, mode="RGB"):
    """Add an image, scaled by ``strength``, to one transformer weight tensor.

    The image pixels are mapped from 0..255 to -0.5..0.5 (float16) and laid
    out as a 3072 x 3072 array:

    mode "RGB": the image is resized to 1024 wide x 3072 high and its three
        channels are flattened into the columns.
    mode "L": the image is resized to 3072 x 3072 and converted to grayscale.

    The tensor named by the dotted path ``layer`` is replaced in place by
    ``tensor + strength * image``.

    Returns a copy of the tensor as it was before patching.

    Raises ValueError for any other ``mode``.
    """
    if mode == "RGB":
        with Image.open(filename) as source:
            # Force three channels first so that grayscale, palette or RGBA
            # files do not break the reshape below; RGB files are unaffected.
            pixels = source.convert("RGB").resize((1024, 3072), Image.LANCZOS)
        image = np.array(pixels, dtype=np.float16) / 255 - 0.5
        image = image.reshape(3072, 3072)
    elif mode == "L":
        with Image.open(filename) as source:
            pixels = source.resize((3072, 3072), Image.LANCZOS).convert("L")
        image = np.array(pixels, dtype=np.float16) / 255 - 0.5
    else:
        # Without this, an unknown mode fails later on an unbound variable.
        raise ValueError(f"unsupported mode: {mode!r} (expected 'RGB' or 'L')")
    parent, name = get_parent_and_name(layer)
    original = mx.array(parent[name])
    # NOTE(review): the image is float16; if the weight has a different dtype,
    # the sum may be promoted to a wider dtype - confirm this is intended.
    parent[name] = parent[name] + strength * mx.array(image)
    return original
def render_transformer(
    dirname="transformer",
    prompt="transformer",
    seed=42,
    layer="transformer_blocks.0.attn.to_out.0.weight",
    n=60,
    strength=0.1,
    mode="RGB",
    size=2160
):
    """Repeatedly render an image and add it to a transformer weight tensor.

    Output goes to
    ``<OUTPUT_DIR>/<dirname>/<prompt>,<seed>,<layer>,<strength>``. For each of
    the ``n`` steps this:

    1. saves the current ``layer`` weights as ``model/<i>.png``,
    2. renders ``prompt`` with ``seed`` to ``image/<i>.png``,
    3. combines both into ``video/<i>.png``,
    4. patches ``layer`` with the rendered image, scaled by ``strength``.

    Existing files are not rendered again, but the patch is applied on every
    step. Finally the ``video`` frames are encoded with ffmpeg into
    ``<output dir>.mp4`` at one frame per second.

    Returns None.
    """
    # Local import: the file's top-level imports do not include subprocess.
    import subprocess

    dirname = f"{OUTPUT_DIR}/{dirname}/{prompt},{seed},{layer},{strength:.06f}"
    for i in range(n):
        filename_model = f"{dirname}/model/{i:08d}.png"
        render_transformer_layer(filename_model, layer, mode=mode)
        filename_image = f"{dirname}/image/{i:08d}.png"
        render_image(filename_image, prompt, seed)
        filename_video = f"{dirname}/video/{i:08d}.png"
        render_grid(filename_video, [filename_model, filename_image], size)
        patch_transformer_with_image(layer, filename_image, strength, mode=mode)
    filename_frames = get_frames(f"{dirname}/video")
    # Pass ffmpeg an argument list instead of a shell string: the paths embed
    # the prompt and layer name, and quotes, "$" or backticks in them would
    # otherwise break the command or be interpreted by the shell.
    command = [
        "ffmpeg", "-y", "-r", "1", "-f", "concat", "-safe", "0",
        "-i", filename_frames,
        "-c:v", "libx264", "-pix_fmt", "yuv420p", "-crf", "18",
        f"{dirname}.mp4",
    ]
    # Encoding stays best-effort, as it was with os.system: report a failure
    # but do not raise, because the frames are already on disk.
    try:
        result = subprocess.run(command, check=False)
    except FileNotFoundError:
        print("ffmpeg not found, skipping video encoding")
    else:
        if result.returncode != 0:
            print(f"ffmpeg failed with exit status {result.returncode}")
if __name__ == "__main__":
    # Command-line arguments are strings, while every other seed in this file
    # (the defaults of render_image and render_transformer) is an int, so
    # convert before passing it on.
    # NOTE(review): generate_image presumably expects an integer seed - confirm.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} SEED")
    try:
        seed = int(sys.argv[1])
    except ValueError:
        sys.exit(f"seed must be an integer, got {sys.argv[1]!r}")
    # `flux` must be a module-level global: render_image and
    # get_parent_and_name read it directly.
    flux = load_model()
    get_layer_list(f"layers_{MODEL_NAME}.txt")
    render_transformer(seed=seed)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment