MLX "NVFP4" vs NVFP4
"""
Accurate comparison of MLX nvfp4 vs NVIDIA NVFP4 implementation.
Key architectural difference:
NVIDIA NVFP4 uses a TWO-LEVEL SCALING strategy:
1. Global FP32 per-tensor scale: s_enc = (6 * 448) / tensor_amax
2. Local E4M3 per-block scale: one scale per 16 elements
MLX appears to use only single-level E4M3 block scaling without the FP32 tensor scale.
This is NOT about signed vs unsigned E4M3 - both use E4M3 for block scales.
The difference is the two-level architecture that enables arbitrary dynamic range.
References:
-----------
[1] NVIDIA. "Pretraining Large Language Models with NVFP4."
arXiv:2509.25149, September 2025.
https://arxiv.org/abs/2509.25149
[2] Alvarez, E., et al. "Introducing NVFP4 for Efficient and Accurate
Low-Precision Inference." NVIDIA Developer Blog, June 2025.
https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/
[3] vLLM NVFP4 Implementation (NVIDIA-contributed):
- Core kernels: csrc/quantization/fp4/nvfp4_quant_kernels.cu
- Utilities: csrc/quantization/fp4/nvfp4_utils.cuh
https://github.com/vllm-project/vllm/tree/main/csrc/quantization/fp4
[4] vLLM Marlin NVFP4 Implementation (optimized GPU kernels):
- csrc/quantization/fp4/marlin/marlin_fp4_kernel.cu
https://github.com/vllm-project/vllm/tree/main/csrc/quantization/fp4/marlin
Note: Marlin uses similar two-level scaling but with optimizations for
batched inference. Closer to production usage but not identical to
NVIDIA's reference implementation.
"""

import numpy as np

# =============================================================================
# Format Constants
# =============================================================================

# E2M1 (FP4) format constants
E2M1_MAX = 6.0          # Maximum representable value
E2M1_MIN_NONZERO = 0.5  # Minimum non-zero value (subnormal: 2^0 * 0.5 = 0.5)
E2M1_VALUES = [-6, -4, -3, -2, -1.5, -1, -0.5, 0, 0.5, 1, 1.5, 2, 3, 4, 6]

# E4M3 format constants (same for NVIDIA and MLX block scales)
E4M3_MAX = 448.0           # (1 + 6/8) * 2^8 = 1.75 * 256 = 448
E4M3_MIN_POSITIVE = 2**-9  # Smallest positive subnormal

# NVFP4 block size
NVFP4_BLOCK_SIZE = 16  # Elements per block (vs 32 for MXFP4)
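
# Illustrative helper (an assumption of this write-up, not taken from any MLX
# or vLLM kernel): E2M1 can only represent the code points in E2M1_VALUES, so
# "quantize to FP4" below means snapping to the nearest of those values rather
# than rounding to a fixed step size.
def quantize_to_e2m1(x):
    """Round a scalar to the nearest representable E2M1 (FP4) value."""
    return min(E2M1_VALUES, key=lambda v: abs(v - x))
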
def analyze_nvidia_nvfp4():
    """
    NVIDIA NVFP4 as specified in arXiv:2509.25149.

    Two-level scaling strategy:
    1. Global encode scale: s_enc = (6 * 448) / amax_x
       - Remaps tensor values into representable range of (FP4 × FP8)
    2. Local block scale in E4M3: per-16-element blocks
       - s_dec_b = amax_b / 6 (scaled by global scale before E4M3 conversion)

    The FP32 global scale provides ARBITRARY dynamic range.
    """
    return {
        "name": "NVIDIA NVFP4 (Two-Level Scaling)",
        "block_size": 16,
        "block_scale_format": "E4M3",
        "block_scale_max": E4M3_MAX,
        "tensor_scale_format": "FP32",
        "tensor_scale_range": "Extended to full FP32 range (~1e-38 to ~3.4e38)",
        # The global scale formula from the paper
        "global_scale_formula": "s_enc = (6 * 448) / tensor_amax",
        # Effective per-block range after global scaling
        "per_block_max": E2M1_MAX * E4M3_MAX,  # 6 * 448 = 2688
        # With FP32 tensor scale, effective range is unlimited
        "effective_max": float("inf"),  # Limited only by FP32
        "effective_min": 0,  # Can represent arbitrarily small values
    }
def analyze_mlx_nvfp4():
    """
    MLX's current NVFP4 implementation (based on Metal kernel analysis).

    Single-level scaling:
    - Uses E4M3 block scales only
    - No FP32 per-tensor global scale

    This limits the effective dynamic range to FP4 × E4M3.
    """
    effective_max = E2M1_MAX * E4M3_MAX  # 6 * 448 = 2688
    effective_min = E2M1_MIN_NONZERO * E4M3_MIN_POSITIVE  # 0.5 * 2^-9
    return {
        "name": "MLX nvfp4 (Single-Level Scaling)",
        "block_size": 16,
        "block_scale_format": "E4M3 (signed)",
        "block_scale_max": E4M3_MAX,
        "tensor_scale_format": "None",
        "tensor_scale_range": "N/A",
        "per_block_max": effective_max,
        "effective_max": effective_max,  # 2688
        "effective_min": effective_min,  # ~0.000977
    }
def analyze_mxfp4_comparison():
    """
    MXFP4 (OCP Microscaling) for reference - what NVFP4 improves upon.

    - Block size: 32 (vs 16 for NVFP4)
    - Scale format: UE8M0 (power-of-two only, no mantissa)
    - No tensor-level scale
    """
    # UE8M0: 8-bit unsigned exponent, range 2^-127 to 2^127
    ue8m0_max = 2**127
    return {
        "name": "MXFP4 (OCP Microscaling)",
        "block_size": 32,
        "block_scale_format": "UE8M0 (power-of-two)",
        "block_scale_max": ue8m0_max,
        "tensor_scale_format": "None",
        "note": "Power-of-two scales can waste up to 1 binade of dynamic range",
    }
def analyze_vllm_marlin_nvfp4():
    """
    vLLM's Marlin NVFP4 implementation - optimized for batched inference.

    Source: csrc/quantization/fp4/marlin/marlin_fp4_kernel.cu

    Key characteristics:
    - Uses two-level scaling like NVIDIA reference
    - Global scale stored per output channel (not per-tensor)
    - Optimized memory access patterns for GPU inference
    - Block size: 16 (same as NVIDIA NVFP4)

    Differences from NVIDIA reference:
    - Scale granularity: per-channel vs per-tensor global scale
    - Memory layout: optimized for specific GPU architectures
    - Focus on inference rather than training
    """
    return {
        "name": "vLLM Marlin NVFP4 (Optimized Inference)",
        "block_size": 16,
        "block_scale_format": "E4M3",
        "block_scale_max": E4M3_MAX,
        "tensor_scale_format": "FP32 (per-channel)",
        "tensor_scale_range": "Extended to full FP32 range",
        "global_scale_formula": "Similar to NVIDIA, per output channel",
        "note": "Closer to NVIDIA spec than MLX, but optimized for inference",
        "source": "https://github.com/vllm-project/vllm/tree/main/csrc/quantization/fp4/marlin",
    }
def compute_dynamic_range_db(max_val, min_val):
    """Compute dynamic range in decibels."""
    if min_val <= 0 or max_val <= 0:
        return float("inf")
    return 20 * np.log10(max_val / min_val)

def demonstrate_two_level_scaling():
    """
    Demonstrate how NVIDIA's two-level scaling works with a concrete example.
    """
    print("\n" + "=" * 80)
    print("NVIDIA TWO-LEVEL SCALING: WORKED EXAMPLE")
    print("=" * 80)

    # Example: tensor with large values (common in LLMs)
    tensor_amax = 50000.0
    block_amax = 45000.0
    print("\nExample tensor with large activations:")
    print(f"  Tensor absolute max (amax_x): {tensor_amax:,.0f}")
    print(f"  Block absolute max (amax_b):  {block_amax:,.0f}")

    # Step 1: Compute global encode scale
    s_enc = (E2M1_MAX * E4M3_MAX) / tensor_amax
    print("\nStep 1: Global FP32 encode scale")
    print(f"  s_enc = (6 × 448) / {tensor_amax:,.0f}")
    print(f"  s_enc = {6 * 448} / {tensor_amax:,.0f}")
    print(f"  s_enc = {s_enc:.6f}")

    # Step 2: Compute local block decode scale
    s_dec_b = block_amax / E2M1_MAX
    print("\nStep 2: Local block decode scale (before global scaling)")
    print(f"  s_dec_b = amax_b / 6 = {block_amax:,.0f} / 6 = {s_dec_b:.2f}")

    # Step 3: Apply global scale to block scale for E4M3 storage
    # (Rounding of the stored scale to the E4M3 grid is omitted for clarity.)
    s_dec_b_scaled = s_dec_b * s_enc
    print("\nStep 3: Scale block scale into E4M3 range")
    print("  s_dec_b_e4m3 = s_dec_b × s_enc")
    print(f"  s_dec_b_e4m3 = {s_dec_b:.2f} × {s_enc:.6f}")
    print(f"  s_dec_b_e4m3 = {s_dec_b_scaled:.4f}")
    print("  (This fits in E4M3 range [0, 448] ✓)")

    # Step 4: Quantize a value
    original_value = 42000.0
    # Per-block encode scale: 1 / s_dec_b = 6 / amax_b
    s_enc_b = E2M1_MAX / block_amax
    scaled_value = original_value * s_enc_b
    # Snap to the nearest representable E2M1 code point
    quantized = min(E2M1_VALUES, key=lambda v: abs(v - scaled_value))
    print(f"\nStep 4: Quantize value {original_value:,.0f}")
    print(f"  Scale into FP4 range: {original_value:,.0f} × (6 / {block_amax:,.0f})")
    print(f"  Scaled value: {scaled_value:.4f}")
    print(f"  Quantized to E2M1: {quantized}")

    # Step 5: Decode back
    decoded = quantized * s_dec_b
    print("\nStep 5: Decode back")
    print(f"  Decoded = {quantized} × {s_dec_b:.2f} = {decoded:,.0f}")
    print(
        f"  Original: {original_value:,.0f}, Error: {abs(original_value - decoded) / original_value * 100:.2f}%"
    )

    print("\n  KEY INSIGHT: The FP32 global scale allows representing values")
    print("  far beyond E4M3's 448 limit by rescaling the entire tensor first!")

def main():
    print("=" * 80)
    print("NVFP4 IMPLEMENTATION COMPARISON: MLX vs NVIDIA")
    print("Based on NVIDIA arXiv:2509.25149 'Pretraining LLMs with NVFP4'")
    print("=" * 80)

    nvidia = analyze_nvidia_nvfp4()
    mlx = analyze_mlx_nvfp4()
    mxfp4 = analyze_mxfp4_comparison()
    marlin = analyze_vllm_marlin_nvfp4()

    # MLX Analysis
    print("\n" + "-" * 80)
    print(f"MLX: {mlx['name']}")
    print("-" * 80)
    print(f"  Block size:         {mlx['block_size']} elements")
    print(f"  Block scale format: {mlx['block_scale_format']}")
    print(f"  Block scale max:    {mlx['block_scale_max']}")
    print(f"  Tensor-level scale: {mlx['tensor_scale_format']}")
    print("\n  Per-block representable range:")
    print(f"    Max: E2M1_max × E4M3_max = 6 × 448 = {mlx['per_block_max']:,}")
    print(f"    Min: E2M1_min × E4M3_min = 0.5 × 2^-9 ≈ {mlx['effective_min']:.6f}")
    mlx_dr = compute_dynamic_range_db(mlx["effective_max"], mlx["effective_min"])
    print(f"  Dynamic range: {mlx_dr:.1f} dB")

    # NVIDIA Analysis
    print("\n" + "-" * 80)
    print(f"NVIDIA: {nvidia['name']}")
    print("-" * 80)
    print(f"  Block size:           {nvidia['block_size']} elements")
    print(f"  Block scale format:   {nvidia['block_scale_format']}")
    print(f"  Block scale max:      {nvidia['block_scale_max']}")
    print(f"  Tensor-level scale:   {nvidia['tensor_scale_format']}")
    print(f"  Global scale formula: {nvidia['global_scale_formula']}")
    print("\n  Per-block representable range (before tensor scale):")
    print(f"    Max: E2M1_max × E4M3_max = 6 × 448 = {nvidia['per_block_max']:,}")
    print("\n  Effective range (with FP32 tensor scale):")
    print("    Extended to FP32 limits (~1.18e-38 to ~3.4e38)")

    # MXFP4 Reference
    print("\n" + "-" * 80)
    print(f"Reference: {mxfp4['name']}")
    print("-" * 80)
    print(f"  Block size:         {mxfp4['block_size']} elements")
    print(f"  Block scale format: {mxfp4['block_scale_format']}")
    print(f"  Note: {mxfp4['note']}")

    # vLLM Marlin Analysis
    print("\n" + "-" * 80)
    print(f"Comparison: {marlin['name']}")
    print("-" * 80)
    print(f"  Block size:         {marlin['block_size']} elements")
    print(f"  Block scale format: {marlin['block_scale_format']}")
    print(f"  Block scale max:    {marlin['block_scale_max']}")
    print(f"  Tensor-level scale: {marlin['tensor_scale_format']}")
    print(f"  Note: {marlin['note']}")
    print(f"  Source: {marlin['source']}")

    # Demonstrate the mechanism
    demonstrate_two_level_scaling()

    # Key differences summary
    print("\n" + "=" * 80)
    print("KEY ARCHITECTURAL DIFFERENCES")
    print("=" * 80)
    print("""
┌─────────────────┬──────────────────┬──────────────────┬──────────────────┐
│ Feature         │ MLX nvfp4        │ vLLM Marlin      │ NVIDIA NVFP4     │
├─────────────────┼──────────────────┼──────────────────┼──────────────────┤
│ Block size      │ 16               │ 16               │ 16               │
│ Block scale     │ E4M3             │ E4M3             │ E4M3             │
│ Tensor scale    │ ✗ None           │ ✓ FP32/channel   │ ✓ FP32/tensor    │
│ Scaling levels  │ 1 (block only)   │ 2 (chan + block) │ 2 (tens + block) │
│ Dynamic range   │ Fixed (~2688)    │ Extended (FP32)  │ Extended (FP32)  │
│ Use case        │ General          │ Inference        │ Training/Infer   │
└─────────────────┴──────────────────┴──────────────────┴──────────────────┘

vLLM Marlin is CLOSER to NVIDIA spec than MLX because it implements two-level
scaling, but uses per-channel rather than per-tensor global scales. This is
an optimization for inference workloads where channel-wise scaling can be
more efficient.
""")

    print("IMPLICATIONS FOR MLX:")
    print("-" * 80)
    print("""
1. WITHOUT FP32 TENSOR SCALE:
   - MLX's effective max is 6 × 448 = 2,688
   - Values > 2,688 will saturate/clip
   - Common in LLM activations, which can exceed 10,000

2. NVIDIA'S TWO-LEVEL APPROACH:
   - First, FP32 scale normalizes tensor into (FP4 × E4M3) range
   - Then, E4M3 block scales handle local variation
   - Result: dynamic range extended to FP32 limits (~1e-38 to ~3.4e38)

3. PAPER QUOTE (Section 2):
   "NVFP4 employs a two-level scaling strategy, which combines a
   fine-grained FP8 scale factor with an FP32 scale applied at
   the tensor level."

4. QUANTIZATION FORMULA (Appendix B.1):
   s_enc = (6 × 448) / amax_x
   where amax_x is the tensor-wide absolute maximum
""")

    print("\n" + "=" * 80)
    print("RECOMMENDATION FOR MLX ISSUE")
    print("=" * 80)
    print("""
To achieve NVIDIA NVFP4 compatibility, MLX needs:

1. Add FP32 per-tensor global scale factor
   - Computed as: s_enc = (6 * 448) / tensor_amax
   - Stored alongside quantized tensor

2. Modify quantization pipeline:
   - Apply global scale BEFORE block quantization
   - Store global decode scale: s_dec = 1 / s_enc

3. Modify dequantization:
   - Apply block E4M3 scales first
   - Then apply global FP32 decode scale

Reference implementations:

- NVIDIA reference: vLLM csrc/quantization/fp4/nvfp4_utils.cuh
  SFScale = (448.f / (Alpha_A / 6.f))
  which simplifies to: (448 * 6) / Alpha_A = 2688 / tensor_amax

- vLLM Marlin (optimized): csrc/quantization/fp4/marlin/marlin_fp4_kernel.cu
  Uses per-channel global scales instead of per-tensor
  May be a reasonable middle ground for inference workloads
""")

    # Print citations
    print("\n" + "=" * 80)
    print("CITATIONS")
    print("=" * 80)
    print("""
[1] NVIDIA. "Pretraining Large Language Models with NVFP4."
    arXiv:2509.25149, September 2025.
    https://arxiv.org/abs/2509.25149

[2] Alvarez, E., et al. "Introducing NVFP4 for Efficient and Accurate
    Low-Precision Inference." NVIDIA Developer Blog, June 2025.
    https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/

[3] vLLM Project. "NVFP4 Quantization Implementation."
    https://github.com/vllm-project/vllm/tree/main/csrc/quantization/fp4
""")


if __name__ == "__main__":
    main()