Benchmarking script for NVFP4 with TorchAO. Pair-coded with Claude Code.
"""
NVFP4 Quantization Benchmark for Flux.1-Dev on Modal B200 GPU
This script benchmarks NVFP4 (4-bit floating point) quantization performance
for the Flux.1-Dev image generation model on NVIDIA B200 GPUs.
B200 GPUs have native FP4 Tensor Core support, making them ideal for
this workload.
Usage:
# No quantization, no compilation
modal run flux_nvfp4_benchmark.py --no-compile
# Weight-only NVFP4 quantization
modal run flux_nvfp4_benchmark.py --quant weight-only --compile
# Dynamic NVFP4 quantization (activations + weights)
modal run flux_nvfp4_benchmark.py --quant dynamic --compile
# With custom batch size
modal run flux_nvfp4_benchmark.py --quant weight-only --compile --batch-size 2
"""
from pathlib import Path
import modal
# Cache configuration
CACHE_DIR = Path("/cache")
cache_volume = modal.Volume.from_name("hf-hub-cache", create_if_missing=True)
volumes = {CACHE_DIR: cache_volume}
# Define the container image with all required dependencies
image = (
modal.Image.from_registry(
"nvidia/cuda:12.9.1-devel-ubuntu24.04",
add_python="3.12",
)
.entrypoint([])
.env({"HF_XET_HIGH_PERFORMANCE": "1"})
.uv_pip_install(
"torch",
"numpy",
extra_options="--index-url https://download.pytorch.org/whl/cu129",
)
.uv_pip_install(
"torchao",
extra_options="--pre --index-url https://download.pytorch.org/whl/nightly/cu129",
)
.uv_pip_install(
"transformers",
"accelerate",
"sentencepiece",
"protobuf",
"huggingface_hub[hf_xet]",
)
.apt_install("git")
.uv_pip_install(
"diffusers @ git+https://github.com/huggingface/diffusers.git@flux-contiguous",
)
)
app = modal.App("flux-nvfp4-benchmark", image=image)
@app.function(
gpu="B200",
timeout=3600,
secrets=[modal.Secret.from_name("huggingface-secret")],
volumes=volumes,
)
def run_flux_benchmark(
quant_mode: str | None = None,
use_compile: bool = False,
batch_size: int = 1,
num_inference_steps: int = 28,
warmup_steps: int = 5,
warmup_iterations: int = 2,
):
"""
Run Flux.1-Dev benchmark with optional NVFP4 quantization.
Args:
quant_mode: Quantization mode - None, "weight-only", or "dynamic"
use_compile: Whether to use torch.compile
batch_size: Number of images to generate per inference
num_inference_steps: Number of diffusion steps for benchmark
warmup_steps: Number of diffusion steps for warmup (fewer for speed)
warmup_iterations: Number of warmup iterations
"""
import functools
import gc
import json
import os
from io import BytesIO
import torch
import torch.utils.benchmark as benchmark
from diffusers import FluxPipeline
# Set HF cache directory
os.environ["HF_HOME"] = str(CACHE_DIR)
os.environ["HF_HUB_CACHE"] = str(CACHE_DIR)
# Print GPU info
print("=" * 70)
print("FLUX.1-Dev NVFP4 Benchmark")
print("=" * 70)
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"Quantization mode: {quant_mode}")
print(f"Torch compile: {use_compile}")
print(f"Batch size: {batch_size}")
print("-" * 70)
# Helper functions
def get_pipe_kwargs(num_steps=28, bs=1):
return {
"prompt": "A cat holding a sign that says hello world",
"height": 1024,
"width": 1024,
"guidance_scale": 3.5,
"num_inference_steps": num_steps,
"max_sequence_length": 512,
"num_images_per_prompt": bs,
"generator": torch.manual_seed(0),
}
def run_inference(pipe, pipe_kwargs):
return pipe(**pipe_kwargs)
# Load the model
print("\nLoading Flux.1-Dev model...")
torch.cuda.reset_peak_memory_stats()
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.bfloat16,
cache_dir=CACHE_DIR,
).to("cuda")
# Commit the volume to persist cached model
cache_volume.commit()
# Apply quantization if requested
if quant_mode is not None:
print(f"\nApplying {quant_mode} NVFP4 quantization...")
from torchao.quantization import quantize_
from torchao.prototype.mx_formats.inference_workflow import (
NVFP4DynamicActivationNVFP4WeightConfig,
NVFP4WeightOnlyConfig,
)
if quant_mode == "weight-only":
config = NVFP4WeightOnlyConfig(
use_dynamic_per_tensor_scale=True,
)
elif quant_mode == "dynamic":
config = NVFP4DynamicActivationNVFP4WeightConfig(
use_dynamic_per_tensor_scale=True,
use_triton_kernel=True,
)
else:
raise ValueError(f"Unknown quant_mode: {quant_mode}")
quantize_(pipe.transformer, config=config)
gc.collect()
torch.cuda.empty_cache()
# Apply torch.compile if requested
if use_compile:
print("\nApplying torch.compile to transformer...")
pipe.transformer.compile_repeated_blocks(fullgraph=True)
# Warmup runs
print(f"\nRunning {warmup_iterations} warmup iterations ({warmup_steps} steps each)...")
torch.cuda.reset_peak_memory_stats()
warmup_kwargs = get_pipe_kwargs(num_steps=warmup_steps, bs=batch_size)
for i in range(warmup_iterations):
print(f" Warmup {i + 1}/{warmup_iterations}")
_ = run_inference(pipe, warmup_kwargs)
torch.cuda.synchronize()
# Benchmark
print(f"\nRunning benchmark ({num_inference_steps} steps, batch size {batch_size})...")
torch.cuda.reset_peak_memory_stats()
benchmark_kwargs = get_pipe_kwargs(num_steps=num_inference_steps, bs=batch_size)
inference_func = functools.partial(run_inference, pipe, benchmark_kwargs)
t0 = benchmark.Timer(
stmt="func()",
globals={"func": inference_func},
num_threads=torch.get_num_threads(),
)
measurement = t0.blocked_autorange(min_run_time=5.0)
latency = float(f"{measurement.mean:.3f}")
torch.cuda.synchronize()
peak_memory_gb = torch.cuda.max_memory_allocated() / 1e9
# Generate final image for output
print("\nGenerating output image...")
final_kwargs = get_pipe_kwargs(num_steps=num_inference_steps, bs=1)
result = pipe(**final_kwargs)
output_image = result.images[0]
# Compile results
results = {
"quant_mode": quant_mode,
"compile": use_compile,
"batch_size": batch_size,
"latency_seconds": latency,
"memory_gb": round(peak_memory_gb, 2),
}
# Print summary
print("\n" + "=" * 70)
print("BENCHMARK RESULTS")
print("=" * 70)
print(f"Quantization: {quant_mode}")
print(f"Torch compile: {use_compile}")
print(f"Batch size: {batch_size}")
print(f"Latency: {latency:.3f} seconds")
print(f"Peak Memory: {peak_memory_gb:.2f} GB")
print("=" * 70)
# Serialize image to bytes
img_buffer = BytesIO()
output_image.save(img_buffer, format="PNG")
img_bytes = img_buffer.getvalue()
return {
"results_json": json.dumps(results, indent=2),
"image_bytes": img_bytes,
}
@app.local_entrypoint()
def main(
quant: str | None = None,
compile: bool = False,
batch_size: int = 1,
steps: int = 28,
):
"""
Entry point when running with `modal run`.
Args:
quant: Quantization mode - None, "weight-only", or "dynamic"
compile: Whether to use torch.compile
batch_size: Number of images to generate per inference
steps: Number of inference steps
"""
import json
from pathlib import Path
print("=" * 70)
print("Starting Flux.1-Dev NVFP4 Benchmark on Modal B200")
print("=" * 70)
print(f"Quantization: {quant}")
print(f"Compile: {compile}")
print(f"Batch size: {batch_size}")
print(f"Steps: {steps}")
print("=" * 70)
# Run benchmark
output = run_flux_benchmark.remote(
quant_mode=quant,
use_compile=compile,
batch_size=batch_size,
num_inference_steps=steps,
)
# Create output filename based on config
compile_str = "compiled" if compile else "eager"
quant_str = quant if quant else "no_quant"
base_name = f"flux_benchmark_{quant_str}_{compile_str}_bs{batch_size}"
# Save results JSON
results_path = Path(f"{base_name}_results.json")
results_path.write_text(output["results_json"])
print(f"\nResults saved to: {results_path}")
# Save output image
image_path = Path(f"{base_name}_output.png")
image_path.write_bytes(output["image_bytes"])
print(f"Image saved to: {image_path}")
# Print summary
results = json.loads(output["results_json"])
print("\n" + "=" * 70)
print("SUMMARY")
print("=" * 70)
print(f"Quantization: {results['quant_mode']}")
print(f"Compile: {results['compile']}")
print(f"Batch size: {results['batch_size']}")
print(f"Latency: {results['latency_seconds']:.3f}s")
print(f"Peak Memory: {results['memory_gb']:.2f} GB")
print("=" * 70)

The script above runs the benchmark on Modal, which keeps all the dependencies (CUDA 12.9, the cu129 PyTorch wheels, a TorchAO nightly, and a Diffusers branch) pinned in a single container image.
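
If you want to try the NVFP4 configs without the Modal scaffolding, the sketch below applies the same TorchAO calls the script uses to a toy model. It is illustrative only: `ToyModel` is a hypothetical stand-in for `pipe.transformer`, and it assumes a TorchAO nightly plus an FP4-capable GPU such as a B200 (the configs live under `torchao.prototype.mx_formats`, so the API may still change).

```python
# Minimal sketch (not the benchmark itself): apply the NVFP4 configs used above
# to a toy model. Assumes a TorchAO nightly and an FP4-capable GPU (e.g., B200);
# ToyModel is a hypothetical stand-in for pipe.transformer.
import torch
import torch.nn as nn
from torchao.quantization import quantize_
from torchao.prototype.mx_formats.inference_workflow import (
    NVFP4DynamicActivationNVFP4WeightConfig,
    NVFP4WeightOnlyConfig,
)


class ToyModel(nn.Module):
    def __init__(self, dim=4096):
        super().__init__()
        self.up = nn.Linear(dim, 4 * dim, bias=False)
        self.down = nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x):
        return self.down(torch.nn.functional.gelu(self.up(x)))


model = ToyModel().to(device="cuda", dtype=torch.bfloat16)

# Option 1: weight-only NVFP4 (only the Linear weights are stored in FP4).
# quantize_(model, config=NVFP4WeightOnlyConfig(use_dynamic_per_tensor_scale=True))

# Option 2: dynamic NVFP4 (activations are also quantized at runtime),
# the variant that gave the largest speedups in the tables below.
quantize_(
    model,
    config=NVFP4DynamicActivationNVFP4WeightConfig(
        use_dynamic_per_tensor_scale=True,
        use_triton_kernel=True,
    ),
)

x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    y = model(x)  # runs through the quantized Linear layers
```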

Some results are below.


batch_size: 1

| quant_mode  | compile | batch_size | latency_seconds | memory_gb |
|-------------|---------|------------|-----------------|-----------|
| dynamic     | True    | 1          | 2.611           | 21.25     |
| none        | True    | 1          | 2.835           | 38.34     |
| weight-only | True    | 1          | 3.23            | 21.24     |

batch_size: 4

| quant_mode  | compile | batch_size | latency_seconds | memory_gb |
|-------------|---------|------------|-----------------|-----------|
| dynamic     | True    | 4          | 8.402           | 27.29     |
| none        | True    | 4          | 10.453          | 44.39     |
| weight-only | True    | 4          | 11.438          | 27.3      |

batch_size: 8

| quant_mode  | compile | batch_size | latency_seconds | memory_gb |
|-------------|---------|------------|-----------------|-----------|
| dynamic     | True    | 8          | 16.249          | 35.91     |
| none        | True    | 8          | 21.668          | 53        |
| weight-only | True    | 8          | 21.706          | 35.92     |

batch_size: 16

| quant_mode  | compile | batch_size | latency_seconds | memory_gb |
|-------------|---------|------------|-----------------|-----------|
| dynamic     | True    | 16         | 31.83           | 48.84     |
| weight-only | True    | 16         | 42.041          | 48.83     |
| none        | True    | 16         | 42.699          | 65.93     |

batch_size: 32

| quant_mode  | compile | batch_size | latency_seconds | memory_gb |
|-------------|---------|------------|-----------------|-----------|
| dynamic     | True    | 32         | 67.532          | 74.69     |
| weight-only | True    | 32         | 83.773          | 74.69     |
| none        | True    | 32         | 84.336          | 91.78     |
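
To put those numbers in perspective, here is a tiny, purely illustrative helper that recomputes the dynamic-vs-baseline speedups and memory savings from the compiled rows above (values copied verbatim from the tables):

```python
# Illustrative post-processing of the tabulated results above
# (compiled runs only; "baseline" = no quantization). Tuples are
# (latency_seconds, memory_gb) keyed by batch size.
baseline = {1: (2.835, 38.34), 4: (10.453, 44.39), 8: (21.668, 53.0),
            16: (42.699, 65.93), 32: (84.336, 91.78)}
dynamic = {1: (2.611, 21.25), 4: (8.402, 27.29), 8: (16.249, 35.91),
           16: (31.83, 48.84), 32: (67.532, 74.69)}

for bs in baseline:
    (t_base, m_base), (t_dyn, m_dyn) = baseline[bs], dynamic[bs]
    print(f"bs={bs:>2}: dynamic speedup {t_base / t_dyn:.2f}x, "
          f"peak memory {m_dyn:.1f} GB vs {m_base:.1f} GB")
# e.g. bs=1: ~1.09x faster with ~45% lower peak memory; bs=32: ~1.25x faster.
```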

There is an ongoing discussion around bmm with NVFP4 that is worth checking out: pytorch/ao#3783
