""" |
|
NVFP4 Quantization Benchmark for Flux.1-Dev on Modal B200 GPU |
|
|
|
This script benchmarks NVFP4 (4-bit floating point) quantization performance |
|
for the Flux.1-Dev image generation model on NVIDIA B200 GPUs. |
|
|
|
B200 GPUs have native FP4 Tensor Core support, making them ideal for |
|
this workload. |
|
|
|
Usage: |
|
# No quantization, no compilation |
|
modal run flux_nvfp4_benchmark.py --no-compile |
|
|
|
# Weight-only NVFP4 quantization |
|
modal run flux_nvfp4_benchmark.py --quant weight-only --compile |
|
|
|
# Dynamic NVFP4 quantization (activations + weights) |
|
modal run flux_nvfp4_benchmark.py --quant dynamic --compile |
|
|
|
# With custom batch size |
|
modal run flux_nvfp4_benchmark.py --quant weight-only --compile --batch-size 2 |
|
""" |
|
|
|
from pathlib import Path

import modal

# Cache configuration
CACHE_DIR = Path("/cache")
cache_volume = modal.Volume.from_name("hf-hub-cache", create_if_missing=True)
volumes = {CACHE_DIR: cache_volume}

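# Note: torchao comes from the PyTorch nightly index since the NVFP4 inference
# configs used below still live under torchao.prototype, and diffusers is
# pinned to a development branch rather than a PyPI release.
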
# Define the container image with all required dependencies
image = (
    modal.Image.from_registry(
        "nvidia/cuda:12.9.1-devel-ubuntu24.04",
        add_python="3.12",
    )
    .entrypoint([])
    .env({"HF_XET_HIGH_PERFORMANCE": "1"})
    .uv_pip_install(
        "torch",
        "numpy",
        extra_options="--index-url https://download.pytorch.org/whl/cu129",
    )
    .uv_pip_install(
        "torchao",
        extra_options="--pre --index-url https://download.pytorch.org/whl/nightly/cu129",
    )
    .uv_pip_install(
        "transformers",
        "accelerate",
        "sentencepiece",
        "protobuf",
        "huggingface_hub[hf_xet]",
    )
    .apt_install("git")
    .uv_pip_install(
        "diffusers @ git+https://github.com/huggingface/diffusers.git@flux-contiguous",
    )
)

app = modal.App("flux-nvfp4-benchmark", image=image)


@app.function(
    gpu="B200",
    timeout=3600,
    secrets=[modal.Secret.from_name("huggingface-secret")],
    volumes=volumes,
)
def run_flux_benchmark(
    quant_mode: str | None = None,
    use_compile: bool = False,
    batch_size: int = 1,
    num_inference_steps: int = 28,
    warmup_steps: int = 5,
    warmup_iterations: int = 2,
):
""" |
|
Run Flux.1-Dev benchmark with optional NVFP4 quantization. |
|
|
|
Args: |
|
quant_mode: Quantization mode - None, "weight-only", or "dynamic" |
|
use_compile: Whether to use torch.compile |
|
batch_size: Number of images to generate per inference |
|
num_inference_steps: Number of diffusion steps for benchmark |
|
warmup_steps: Number of diffusion steps for warmup (fewer for speed) |
|
warmup_iterations: Number of warmup iterations |
|
""" |
|
    import functools
    import gc
    import json
    import os
    from io import BytesIO

    import torch
    import torch.utils.benchmark as benchmark
    from diffusers import FluxPipeline

    # Set HF cache directory
    os.environ["HF_HOME"] = str(CACHE_DIR)
    os.environ["HF_HUB_CACHE"] = str(CACHE_DIR)

    # Print GPU info
    print("=" * 70)
    print("FLUX.1-Dev NVFP4 Benchmark")
    print("=" * 70)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Quantization mode: {quant_mode}")
    print(f"Torch compile: {use_compile}")
    print(f"Batch size: {batch_size}")
    print("-" * 70)

    # Helper functions
    def get_pipe_kwargs(num_steps=28, bs=1):
        return {
            "prompt": "A cat holding a sign that says hello world",
            "height": 1024,
            "width": 1024,
            "guidance_scale": 3.5,
            "num_inference_steps": num_steps,
            "max_sequence_length": 512,
            "num_images_per_prompt": bs,
            "generator": torch.manual_seed(0),
        }

    def run_inference(pipe, pipe_kwargs):
        return pipe(**pipe_kwargs)

    # Load the model
    print("\nLoading Flux.1-Dev model...")
    torch.cuda.reset_peak_memory_stats()

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
        cache_dir=CACHE_DIR,
    ).to("cuda")

    # Commit the volume to persist cached model
    cache_volume.commit()

    # Apply quantization if requested
    if quant_mode is not None:
        print(f"\nApplying {quant_mode} NVFP4 quantization...")
        from torchao.quantization import quantize_
        from torchao.prototype.mx_formats.inference_workflow import (
            NVFP4DynamicActivationNVFP4WeightConfig,
            NVFP4WeightOnlyConfig,
        )

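        # Weight-only mode compresses just the transformer weights to NVFP4
        # (primarily a memory saving), while dynamic mode also quantizes
        # activations at runtime so the matmuls can use FP4 Tensor Cores.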
        if quant_mode == "weight-only":
            config = NVFP4WeightOnlyConfig(
                use_dynamic_per_tensor_scale=True,
            )
        elif quant_mode == "dynamic":
            config = NVFP4DynamicActivationNVFP4WeightConfig(
                use_dynamic_per_tensor_scale=True,
                use_triton_kernel=True,
            )
        else:
            raise ValueError(f"Unknown quant_mode: {quant_mode}")

        quantize_(pipe.transformer, config=config)

        gc.collect()
        torch.cuda.empty_cache()

    # Apply torch.compile if requested
    if use_compile:
        print("\nApplying torch.compile to transformer...")
        pipe.transformer.compile_repeated_blocks(fullgraph=True)

    # Warmup runs (not timed): the first calls absorb one-time costs such as
    # torch.compile graph capture and CUDA kernel autotuning.
    print(f"\nRunning {warmup_iterations} warmup iterations ({warmup_steps} steps each)...")
    torch.cuda.reset_peak_memory_stats()
    warmup_kwargs = get_pipe_kwargs(num_steps=warmup_steps, bs=batch_size)

    for i in range(warmup_iterations):
        print(f"  Warmup {i + 1}/{warmup_iterations}")
        _ = run_inference(pipe, warmup_kwargs)
        torch.cuda.synchronize()

    # Benchmark
    print(f"\nRunning benchmark ({num_inference_steps} steps, batch size {batch_size})...")
    torch.cuda.reset_peak_memory_stats()
    benchmark_kwargs = get_pipe_kwargs(num_steps=num_inference_steps, bs=batch_size)
    inference_func = functools.partial(run_inference, pipe, benchmark_kwargs)

    t0 = benchmark.Timer(
        stmt="func()",
        globals={"func": inference_func},
        num_threads=torch.get_num_threads(),
    )
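
    # blocked_autorange() calls the timed function repeatedly until at least
    # min_run_time seconds of measurements are collected and returns per-call
    # statistics; measurement.mean is the average latency of one pipeline call.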
    measurement = t0.blocked_autorange(min_run_time=5.0)
    latency = round(measurement.mean, 3)

    torch.cuda.synchronize()
    peak_memory_gb = torch.cuda.max_memory_allocated() / 1e9

    # Generate final image for output
    print("\nGenerating output image...")
    final_kwargs = get_pipe_kwargs(num_steps=num_inference_steps, bs=1)
    result = pipe(**final_kwargs)
    output_image = result.images[0]

    # Compile results
    results = {
        "quant_mode": quant_mode,
        "compile": use_compile,
        "batch_size": batch_size,
        "latency_seconds": latency,
        "memory_gb": round(peak_memory_gb, 2),
    }

    # Print summary
    print("\n" + "=" * 70)
    print("BENCHMARK RESULTS")
    print("=" * 70)
    print(f"Quantization: {quant_mode}")
    print(f"Torch compile: {use_compile}")
    print(f"Batch size: {batch_size}")
    print(f"Latency: {latency:.3f} seconds")
    print(f"Peak Memory: {peak_memory_gb:.2f} GB")
    print("=" * 70)

    # Serialize image to bytes
    img_buffer = BytesIO()
    output_image.save(img_buffer, format="PNG")
    img_bytes = img_buffer.getvalue()

    return {
        "results_json": json.dumps(results, indent=2),
        "image_bytes": img_bytes,
    }


@app.local_entrypoint()
def main(
    quant: str | None = None,
    compile: bool = False,
    batch_size: int = 1,
    steps: int = 28,
):
""" |
|
Entry point when running with `modal run`. |
|
|
|
Args: |
|
quant: Quantization mode - None, "weight-only", or "dynamic" |
|
compile: Whether to use torch.compile |
|
batch_size: Number of images to generate per inference |
|
steps: Number of inference steps |
|
""" |
|
    import json
    from pathlib import Path

print("=" * 70) |
|
print("Starting Flux.1-Dev NVFP4 Benchmark on Modal B200") |
|
print("=" * 70) |
|
print(f"Quantization: {quant}") |
|
print(f"Compile: {compile}") |
|
print(f"Batch size: {batch_size}") |
|
print(f"Steps: {steps}") |
|
print("=" * 70) |
|
|
|
    # Run benchmark
    output = run_flux_benchmark.remote(
        quant_mode=quant,
        use_compile=compile,
        batch_size=batch_size,
        num_inference_steps=steps,
    )

    # Create output filename based on config
    compile_str = "compiled" if compile else "eager"
    quant_str = quant if quant else "no_quant"
    base_name = f"flux_benchmark_{quant_str}_{compile_str}_bs{batch_size}"

    # Save results JSON
    results_path = Path(f"{base_name}_results.json")
    results_path.write_text(output["results_json"])
    print(f"\nResults saved to: {results_path}")

    # Save output image
    image_path = Path(f"{base_name}_output.png")
    image_path.write_bytes(output["image_bytes"])
    print(f"Image saved to: {image_path}")

    # Print summary
    results = json.loads(output["results_json"])
    print("\n" + "=" * 70)
    print("SUMMARY")
    print("=" * 70)
    print(f"Quantization: {results['quant_mode']}")
    print(f"Compile: {results['compile']}")
    print(f"Batch size: {results['batch_size']}")
    print(f"Latency: {results['latency_seconds']:.3f}s")
    print(f"Peak Memory: {results['memory_gb']:.2f} GB")
    print("=" * 70)