// paged_attention.metal - Paged Attention for Apple Silicon
//
// Implements vLLM-style paged attention adapted for Metal simdgroup architecture.
// Paged attention decouples logical token positions from physical memory layout,
// enabling efficient batch serving with variable-length sequences.
//
// Key differences from flash_attention.metal:
// - KV cache is organized in fixed-size blocks (pages)
// - Block tables map logical block indices to physical block addresses
// - Each sequence can have a different context length
// - Designed for decode phase: Q has exactly 1 token per sequence
// - Supports batch serving (multiple sequences with different lengths)
//
// Kernel variants:
// 1. paged_attention_v1 - Single-pass (short-medium contexts)
// 2. paged_attention_v2 - Multi-partition (long contexts, >1024 tokens)
// 3. paged_attention_v1_fp4 - FP4-quantized KV blocks
// 4. paged_attention_v1_int4 - INT4-quantized KV blocks
//
// KV block layout (physical):
// [num_blocks, num_kv_heads, block_size, head_dim]
//
// Block table:
// [num_seqs, max_blocks_per_seq] - maps (seq, logical_block) -> physical_block
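//
// Worked example (illustrative numbers): with BLOCK_SIZE = 16, token position 37 of
// sequence s lives in logical block 37 / 16 = 2, row 37 % 16 = 5. Its K vector for
// kv_head h therefore starts at
//   phys = block_tables[s * max_blocks_per_seq + 2]
//   K_cache + ((phys * num_kv_heads + h) * 16 + 5) * head_dim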
//
// Dispatch:
// paged_attention_v1: [num_seqs, num_heads_q, 1] threadgroups, 128 threads/tg
// paged_attention_v2: [num_seqs, num_heads_q, num_partitions] threadgroups
//
// CUDA -> Metal mapping:
// vLLM uses 1 warp per head (32 threads), multiple warps for parallel reduce.
// Metal uses 1 simdgroup per head, 4 simdgroups per threadgroup for occupancy.
// Since decode Q is always 1 token, simdgroup 0 handles the attention math for its
// (seq, head) pair, while all 4 simdgroups in the threadgroup cooperate on KV block loading.
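//
// Host-side dispatch sketch (illustrative only; assumes a Swift host, with encoder and
// pipeline setup omitted and names such as numSeqs/maxContextLen being placeholders,
// not part of this file):
//   let tgSize     = MTLSizeMake(128, 1, 1)                       // THREADS_PER_TG
//   let partitions = (maxContextLen + 255) / 256                  // ceil-div by PARTITION_SIZE
//   encoder.dispatchThreadgroups(MTLSizeMake(numSeqs, numHeadsQ, 1),          threadsPerThreadgroup: tgSize) // v1
//   encoder.dispatchThreadgroups(MTLSizeMake(numSeqs, numHeadsQ, partitions), threadsPerThreadgroup: tgSize) // v2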
#include <metal_stdlib>
using namespace metal;
// ---------------------------------------------------------------------------
// Configuration
// ---------------------------------------------------------------------------
// Tokens per KV block. Must match the page size used by the Python scheduler.
// 16 is the standard vLLM block size; powers of 2 enable shift-based indexing.
constant constexpr uint BLOCK_SIZE = 16;
// Maximum supported head dimension
constant constexpr uint HEAD_DIM_MAX = 128;
// Threads per simdgroup (Apple Silicon fixed at 32)
constant constexpr uint SIMD_SIZE = 32;
// Simdgroups per threadgroup. 4 gives 128 threads total.
// All 4 simdgroups cooperate on loading KV blocks; simdgroup 0 does the
// attention computation. This hides load latency behind compute.
constant constexpr uint NUM_SIMDGROUPS = 4;
constant constexpr uint THREADS_PER_TG = SIMD_SIZE * NUM_SIMDGROUPS; // 128
// Number of KV blocks loaded into threadgroup memory simultaneously.
// Double-buffering: compute on one while loading the next.
constant constexpr uint KV_TILES = 2;
// For v2 partitioning: maximum tokens processed per partition
constant constexpr uint PARTITION_SIZE = 256;
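// Example: a 4096-token context spans 4096 / 16 = 256 KV blocks, which the v2 path
// splits into 4096 / 256 = 16 partitions of 256 / 16 = 16 blocks each.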
// FP4 packing
constant constexpr uint FP4_PER_UINT = 8;
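// Each uint32 packs 8 nibbles, so head_dim = 128 occupies 128 / 8 = 16 packed words per
// token per KV head (see packed_head_dim in the quantized kernels below).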
// ---------------------------------------------------------------------------
// Utility functions (shared with flash_attention.metal patterns)
// ---------------------------------------------------------------------------
inline float simd_reduce_sum(float val) {
val += simd_shuffle_xor(val, 16);
val += simd_shuffle_xor(val, 8);
val += simd_shuffle_xor(val, 4);
val += simd_shuffle_xor(val, 2);
val += simd_shuffle_xor(val, 1);
return val;
}
inline float simd_reduce_max(float val) {
val = max(val, simd_shuffle_xor(val, 16));
val = max(val, simd_shuffle_xor(val, 8));
val = max(val, simd_shuffle_xor(val, 4));
val = max(val, simd_shuffle_xor(val, 2));
val = max(val, simd_shuffle_xor(val, 1));
return val;
}
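// The xor-butterfly visits strides 16, 8, 4, 2, 1, so after the final step every lane in
// the simdgroup holds the full reduction; no broadcast is needed before all lanes consume
// the result, which the score computation below relies on.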
// ---------------------------------------------------------------------------
// FP4/INT4 dequantization helpers (same as flash_attention.metal)
// ---------------------------------------------------------------------------
inline half dequant_fp4(uint nibble, half scale) {
uint sign_bit = (nibble >> 3) & 1;
uint exp_bits = (nibble >> 1) & 0x3;
uint man_bit = nibble & 1;
half magnitude;
if (exp_bits == 0) {
// Subnormal: with exponent bias 1, value = man_bit * 0.5 (E2M1 codes 0 and 0.5)
magnitude = half(man_bit) * half(0.5h);
} else {
half power = half(1u << (exp_bits - 1));
half mantissa = half(1.0h) + half(man_bit) * half(0.5h);
magnitude = power * mantissa;
}
half result = sign_bit ? -magnitude : magnitude;
return result * scale;
}
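// E2M1 magnitudes for the low 3 bits (bit 3 flips the sign), matching the standard FP4
// E2M1 value set:
//   0b000 -> 0.0   0b001 -> 0.5   0b010 -> 1.0   0b011 -> 1.5
//   0b100 -> 2.0   0b101 -> 3.0   0b110 -> 4.0   0b111 -> 6.0
// e.g. nibble 0xD (sign=1, exp=2, man=1) dequantizes to -3.0 * scale.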
inline half dequant_int4(uint nibble, half scale) {
int signed_val = int(nibble & 0xFu) - 8;
return half(signed_val) * scale;
}
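// Symmetric INT4 example: nibble 0x0 -> -8 * scale, 0x8 -> 0, 0xF -> +7 * scale.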
// ---------------------------------------------------------------------------
// Paged Attention V1 - Single-pass decode attention
//
// For each sequence in the batch, computes:
// output[seq][head] = softmax(Q[seq][head] @ K_cache^T / sqrt(d)) @ V_cache
//
// where K_cache and V_cache are scattered across physical blocks according
// to the block table.
//
// Dispatch: [num_seqs, num_heads_q, 1] threadgroups
// THREADS_PER_TG (128) threads per threadgroup
//
// Each threadgroup handles one (sequence, Q-head) pair.
// The Q vector is loaded into registers distributed across simdgroup 0.
// All 4 simdgroups cooperate to load KV blocks into threadgroup memory.
// Simdgroup 0 computes the dot products and online softmax.
// ---------------------------------------------------------------------------
kernel void paged_attention_v1(
device const half* Q [[buffer(0)]], // [num_seqs, num_heads_q, head_dim]
device const half* K_cache [[buffer(1)]], // [num_blocks, num_kv_heads, block_size, head_dim]
device const half* V_cache [[buffer(2)]], // [num_blocks, num_kv_heads, block_size, head_dim]
device const int* block_tables [[buffer(3)]], // [num_seqs, max_blocks_per_seq]
device const int* context_lens [[buffer(4)]], // [num_seqs]
device half* output [[buffer(5)]], // [num_seqs, num_heads_q, head_dim]
constant uint& num_seqs [[buffer(6)]],
constant uint& num_heads_q [[buffer(7)]],
constant uint& num_kv_heads [[buffer(8)]],
constant uint& head_dim [[buffer(9)]],
constant uint& max_blocks_per_seq [[buffer(10)]],
constant float& scale [[buffer(11)]],
uint3 tgid [[threadgroup_position_in_grid]],
uint tid_in_tg [[thread_index_in_threadgroup]],
uint lane_id [[thread_index_in_simdgroup]],
uint sg_id [[simdgroup_index_in_threadgroup]]
) {
const uint seq_idx = tgid.x;
const uint head_q = tgid.y;
if (seq_idx >= num_seqs) return;
// GQA: map Q head to KV head
const uint gqa_ratio = num_heads_q / num_kv_heads;
const uint head_kv = head_q / gqa_ratio;
// Context length for this sequence
const int ctx_len = context_lens[seq_idx];
if (ctx_len <= 0) return;
const uint context_len = uint(ctx_len);
const uint num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
// Q layout: [num_seqs, num_heads_q, head_dim]
const uint q_offset = seq_idx * num_heads_q * head_dim + head_q * head_dim;
// Load Q vector into registers (distributed across lanes of simdgroup 0)
// Each lane holds head_dim/32 consecutive elements.
const uint elems_per_lane = head_dim / SIMD_SIZE;
float q_reg[HEAD_DIM_MAX / SIMD_SIZE];
if (sg_id == 0) {
for (uint i = 0; i < elems_per_lane; ++i) {
uint d = lane_id * elems_per_lane + i;
q_reg[i] = (d < head_dim) ? float(Q[q_offset + d]) : 0.0f;
}
}
// Threadgroup memory for the KV blocks (double-buffered: two blocks resident at once).
// Each block has BLOCK_SIZE tokens, each with head_dim elements.
// Memory per array: KV_TILES * BLOCK_SIZE * HEAD_DIM_MAX * sizeof(half) = 2 * 16 * 128 * 2 = 8 KB,
// so K_smem and V_smem together use 16 KB, within the 32 KB threadgroup memory budget.
threadgroup half K_smem[KV_TILES][BLOCK_SIZE][HEAD_DIM_MAX];
threadgroup half V_smem[KV_TILES][BLOCK_SIZE][HEAD_DIM_MAX];
// Online softmax state (simdgroup 0 only, but all lanes participate)
float m_prev = -INFINITY;
float l_prev = 0.0f;
float o_acc[HEAD_DIM_MAX / SIMD_SIZE];
for (uint i = 0; i < elems_per_lane; ++i) {
o_acc[i] = 0.0f;
}
// KV block stride: [num_blocks, num_kv_heads, block_size, head_dim]
const uint kv_head_stride = BLOCK_SIZE * head_dim;
const uint kv_block_stride = num_kv_heads * kv_head_stride;
// Block table for this sequence
device const int* seq_block_table = block_tables + seq_idx * max_blocks_per_seq;
// Preload first block into buffer 0
{
int phys_block = seq_block_table[0];
uint kv_base = uint(phys_block) * kv_block_stride + head_kv * kv_head_stride;
uint elems_to_load = BLOCK_SIZE * head_dim;
uint loads_per_thread = (elems_to_load + THREADS_PER_TG - 1) / THREADS_PER_TG;
for (uint i = 0; i < loads_per_thread; ++i) {
uint idx = tid_in_tg + i * THREADS_PER_TG;
if (idx < elems_to_load) {
uint token_in_block = idx / head_dim;
uint d = idx % head_dim;
K_smem[0][token_in_block][d] = K_cache[kv_base + token_in_block * head_dim + d];
V_smem[0][token_in_block][d] = V_cache[kv_base + token_in_block * head_dim + d];
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// ---------------------------------------------------------------------------
// Main loop: iterate over KV blocks with double-buffering
// ---------------------------------------------------------------------------
uint buf_compute = 0;
for (uint block_idx = 0; block_idx < num_blocks; ++block_idx) {
uint buf_load = 1 - buf_compute;
// Async load next block (all threads participate)
if (block_idx + 1 < num_blocks) {
int next_phys_block = seq_block_table[block_idx + 1];
uint kv_base = uint(next_phys_block) * kv_block_stride + head_kv * kv_head_stride;
uint elems_to_load = BLOCK_SIZE * head_dim;
uint loads_per_thread = (elems_to_load + THREADS_PER_TG - 1) / THREADS_PER_TG;
for (uint i = 0; i < loads_per_thread; ++i) {
uint idx = tid_in_tg + i * THREADS_PER_TG;
if (idx < elems_to_load) {
uint token_in_block = idx / head_dim;
uint d = idx % head_dim;
K_smem[buf_load][token_in_block][d] = K_cache[kv_base + token_in_block * head_dim + d];
V_smem[buf_load][token_in_block][d] = V_cache[kv_base + token_in_block * head_dim + d];
}
}
}
// Compute attention for current block (simdgroup 0)
// Determine valid tokens in this block
uint block_start_token = block_idx * BLOCK_SIZE;
uint block_tokens = min(uint(BLOCK_SIZE), context_len - block_start_token);
if (sg_id == 0) {
// Compute QK^T for each valid token in this block
float scores[BLOCK_SIZE];
for (uint t = 0; t < block_tokens; ++t) {
float dot = 0.0f;
for (uint i = 0; i < elems_per_lane; ++i) {
uint d = lane_id * elems_per_lane + i;
dot += q_reg[i] * float(K_smem[buf_compute][t][d]);
}
dot = simd_reduce_sum(dot);
scores[t] = dot * scale;
}
// Mask invalid positions
for (uint t = block_tokens; t < BLOCK_SIZE; ++t) {
scores[t] = -INFINITY;
}
// Online softmax update for this block
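// Flash-attention-style merge. After every block the running state satisfies
//   m = max over all scores seen so far
//   l = sum_t exp(s_t - m)
//   o = sum_t exp(s_t - m) * V[t]
// so rescaling the old l and o by exp(m_prev - m_new) keeps earlier contributions
// consistent with the new maximum; the final output is o / l.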
float m_block = -INFINITY;
for (uint t = 0; t < block_tokens; ++t) {
m_block = max(m_block, scores[t]);
}
float m_new = max(m_prev, m_block);
float correction = exp(m_prev - m_new);
// Rescale running sum and add new exponentials
float l_new = l_prev * correction;
for (uint t = 0; t < block_tokens; ++t) {
l_new += exp(scores[t] - m_new);
}
// Rescale output accumulator
for (uint i = 0; i < elems_per_lane; ++i) {
o_acc[i] *= correction;
}
// Accumulate weighted V
for (uint t = 0; t < block_tokens; ++t) {
float p = exp(scores[t] - m_new);
for (uint i = 0; i < elems_per_lane; ++i) {
uint d = lane_id * elems_per_lane + i;
o_acc[i] += p * float(V_smem[buf_compute][t][d]);
}
}
m_prev = m_new;
l_prev = l_new;
}
// Barrier before swapping buffers (all threads must finish loading)
threadgroup_barrier(mem_flags::mem_threadgroup);
buf_compute = buf_load;
}
// ---------------------------------------------------------------------------
// Store output (simdgroup 0 only)
// ---------------------------------------------------------------------------
if (sg_id == 0) {
const uint o_offset = seq_idx * num_heads_q * head_dim + head_q * head_dim;
float inv_l = (l_prev > 0.0f) ? (1.0f / l_prev) : 0.0f;
for (uint i = 0; i < elems_per_lane; ++i) {
uint d = lane_id * elems_per_lane + i;
if (d < head_dim) {
output[o_offset + d] = half(o_acc[i] * inv_l);
}
}
}
}
// ---------------------------------------------------------------------------
// Paged Attention V2 - Multi-partition for long contexts
//
// For sequences with many KV blocks, v1 serializes over all blocks in one
// threadgroup, limiting parallelism. V2 partitions the blocks across multiple
// threadgroups along the z-axis, each computing a partial softmax result.
// A final reduction combines partitions using the log-sum-exp trick.
//
// Phase 1 (this kernel): Each partition computes:
// - partial output accumulator (unnormalized)
// - partial max logit (m)
// - partial sum of exponentials (l)
//
// Phase 2 (paged_attention_v2_reduce): Combines partitions.
//
// Dispatch: [num_seqs, num_heads_q, num_partitions] threadgroups
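//
// Example (illustrative): a sequence with context_len = 4096 yields 16 partitions of
// 256 tokens; phase 1 writes 16 (m, l, partial O) triples per (seq, head), and phase 2
// folds them into the final head_dim-long output vector.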
// ---------------------------------------------------------------------------
kernel void paged_attention_v2(
device const half* Q [[buffer(0)]],
device const half* K_cache [[buffer(1)]],
device const half* V_cache [[buffer(2)]],
device const int* block_tables [[buffer(3)]],
device const int* context_lens [[buffer(4)]],
device float* partial_out [[buffer(5)]], // [num_seqs, num_heads_q, max_partitions, head_dim]
device float* partial_m [[buffer(6)]], // [num_seqs, num_heads_q, max_partitions]
device float* partial_l [[buffer(7)]], // [num_seqs, num_heads_q, max_partitions]
constant uint& num_seqs [[buffer(8)]],
constant uint& num_heads_q [[buffer(9)]],
constant uint& num_kv_heads [[buffer(10)]],
constant uint& head_dim [[buffer(11)]],
constant uint& max_blocks_per_seq [[buffer(12)]],
constant uint& max_partitions [[buffer(13)]],
constant float& scale [[buffer(14)]],
uint3 tgid [[threadgroup_position_in_grid]],
uint tid_in_tg [[thread_index_in_threadgroup]],
uint lane_id [[thread_index_in_simdgroup]],
uint sg_id [[simdgroup_index_in_threadgroup]]
) {
const uint seq_idx = tgid.x;
const uint head_q = tgid.y;
const uint partition_idx = tgid.z;
if (seq_idx >= num_seqs) return;
const uint gqa_ratio = num_heads_q / num_kv_heads;
const uint head_kv = head_q / gqa_ratio;
const int ctx_len = context_lens[seq_idx];
if (ctx_len <= 0) return;
const uint context_len = uint(ctx_len);
// Determine which blocks this partition handles
// Each partition handles PARTITION_SIZE / BLOCK_SIZE blocks
const uint blocks_per_partition = PARTITION_SIZE / BLOCK_SIZE; // 16
const uint total_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
const uint partition_start_block = partition_idx * blocks_per_partition;
if (partition_start_block >= total_blocks) return; // This partition has no work
const uint partition_end_block = min(partition_start_block + blocks_per_partition, total_blocks);
const uint partition_num_blocks = partition_end_block - partition_start_block;
// Load Q
const uint q_offset = seq_idx * num_heads_q * head_dim + head_q * head_dim;
const uint elems_per_lane = head_dim / SIMD_SIZE;
float q_reg[HEAD_DIM_MAX / SIMD_SIZE];
if (sg_id == 0) {
for (uint i = 0; i < elems_per_lane; ++i) {
uint d = lane_id * elems_per_lane + i;
q_reg[i] = (d < head_dim) ? float(Q[q_offset + d]) : 0.0f;
}
}
threadgroup half K_smem[KV_TILES][BLOCK_SIZE][HEAD_DIM_MAX];
threadgroup half V_smem[KV_TILES][BLOCK_SIZE][HEAD_DIM_MAX];
float m_prev = -INFINITY;
float l_prev = 0.0f;
float o_acc[HEAD_DIM_MAX / SIMD_SIZE];
for (uint i = 0; i < elems_per_lane; ++i) {
o_acc[i] = 0.0f;
}
const uint kv_head_stride = BLOCK_SIZE * head_dim;
const uint kv_block_stride = num_kv_heads * kv_head_stride;
device const int* seq_block_table = block_tables + seq_idx * max_blocks_per_seq;
// Preload first block of this partition
{
uint blk = partition_start_block;
int phys_block = seq_block_table[blk];
uint kv_base = uint(phys_block) * kv_block_stride + head_kv * kv_head_stride;
uint elems_to_load = BLOCK_SIZE * head_dim;
uint loads_per_thread = (elems_to_load + THREADS_PER_TG - 1) / THREADS_PER_TG;
for (uint i = 0; i < loads_per_thread; ++i) {
uint idx = tid_in_tg + i * THREADS_PER_TG;
if (idx < elems_to_load) {
uint token_in_block = idx / head_dim;
uint d = idx % head_dim;
K_smem[0][token_in_block][d] = K_cache[kv_base + token_in_block * head_dim + d];
V_smem[0][token_in_block][d] = V_cache[kv_base + token_in_block * head_dim + d];
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
uint buf_compute = 0;
for (uint blk_offset = 0; blk_offset < partition_num_blocks; ++blk_offset) {
uint buf_load = 1 - buf_compute;
uint abs_block = partition_start_block + blk_offset;
// Load next block
if (blk_offset + 1 < partition_num_blocks) {
uint next_abs_block = partition_start_block + blk_offset + 1;
int next_phys_block = seq_block_table[next_abs_block];
uint kv_base = uint(next_phys_block) * kv_block_stride + head_kv * kv_head_stride;
uint elems_to_load = BLOCK_SIZE * head_dim;
uint loads_per_thread = (elems_to_load + THREADS_PER_TG - 1) / THREADS_PER_TG;
for (uint i = 0; i < loads_per_thread; ++i) {
uint idx = tid_in_tg + i * THREADS_PER_TG;
if (idx < elems_to_load) {
uint token_in_block = idx / head_dim;
uint d = idx % head_dim;
K_smem[buf_load][token_in_block][d] = K_cache[kv_base + token_in_block * head_dim + d];
V_smem[buf_load][token_in_block][d] = V_cache[kv_base + token_in_block * head_dim + d];
}
}
}
// Compute on current block
uint block_start_token = abs_block * BLOCK_SIZE;
uint block_tokens = min(uint(BLOCK_SIZE), context_len - block_start_token);
if (sg_id == 0) {
float scores[BLOCK_SIZE];
for (uint t = 0; t < block_tokens; ++t) {
float dot = 0.0f;
for (uint i = 0; i < elems_per_lane; ++i) {
uint d = lane_id * elems_per_lane + i;
dot += q_reg[i] * float(K_smem[buf_compute][t][d]);
}
dot = simd_reduce_sum(dot);
scores[t] = dot * scale;
}
for (uint t = block_tokens; t < BLOCK_SIZE; ++t) {
scores[t] = -INFINITY;
}
float m_block = -INFINITY;
for (uint t = 0; t < block_tokens; ++t) {
m_block = max(m_block, scores[t]);
}
float m_new = max(m_prev, m_block);
float correction = exp(m_prev - m_new);
float l_new = l_prev * correction;
for (uint t = 0; t < block_tokens; ++t) {
l_new += exp(scores[t] - m_new);
}
for (uint i = 0; i < elems_per_lane; ++i) {
o_acc[i] *= correction;
}
for (uint t = 0; t < block_tokens; ++t) {
float p = exp(scores[t] - m_new);
for (uint i = 0; i < elems_per_lane; ++i) {
uint d = lane_id * elems_per_lane + i;
o_acc[i] += p * float(V_smem[buf_compute][t][d]);
}
}
m_prev = m_new;
l_prev = l_new;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
buf_compute = buf_load;
}
// Store partial results (simdgroup 0)
if (sg_id == 0) {
// partial_out: [num_seqs, num_heads_q, max_partitions, head_dim]
const uint po_offset = ((seq_idx * num_heads_q + head_q) * max_partitions + partition_idx) * head_dim;
for (uint i = 0; i < elems_per_lane; ++i) {
uint d = lane_id * elems_per_lane + i;
if (d < head_dim) {
partial_out[po_offset + d] = o_acc[i];
}
}
// partial_m/l: [num_seqs, num_heads_q, max_partitions]
if (lane_id == 0) {
const uint pm_offset = (seq_idx * num_heads_q + head_q) * max_partitions + partition_idx;
partial_m[pm_offset] = m_prev;
partial_l[pm_offset] = l_prev;
}
}
}
// ---------------------------------------------------------------------------
// Paged Attention V2 - Reduction kernel
//
// Combines partial results from paged_attention_v2 using the log-sum-exp trick:
// m_global = max(m_0, m_1, ..., m_P)
// l_global = sum_p(l_p * exp(m_p - m_global))
// O = sum_p(O_p * exp(m_p - m_global)) / l_global
//
// Dispatch: [num_seqs, num_heads_q, 1] threadgroups, SIMD_SIZE threads
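//
// Numeric example (illustrative): two partitions with m_0 = 2, l_0 = 10 and m_1 = 5,
// l_1 = 4 give m_global = 5, so l_global = 10 * exp(-3) + 4 ~= 4.498, and partition 0's
// accumulator is scaled by exp(-3) ~= 0.0498 before the final division by l_global.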
// ---------------------------------------------------------------------------
kernel void paged_attention_v2_reduce(
device const float* partial_out [[buffer(0)]], // [num_seqs, num_heads_q, max_partitions, head_dim]
device const float* partial_m [[buffer(1)]], // [num_seqs, num_heads_q, max_partitions]
device const float* partial_l [[buffer(2)]], // [num_seqs, num_heads_q, max_partitions]
device const int* context_lens [[buffer(3)]], // [num_seqs]
device half* output [[buffer(4)]], // [num_seqs, num_heads_q, head_dim]
constant uint& num_seqs [[buffer(5)]],
constant uint& num_heads_q [[buffer(6)]],
constant uint& head_dim [[buffer(7)]],
constant uint& max_partitions [[buffer(8)]],
uint3 tgid [[threadgroup_position_in_grid]],
uint tid_in_tg [[thread_index_in_threadgroup]],
uint lane_id [[thread_index_in_simdgroup]]
) {
const uint seq_idx = tgid.x;
const uint head_q = tgid.y;
if (seq_idx >= num_seqs) return;
const int ctx_len = context_lens[seq_idx];
if (ctx_len <= 0) return;
const uint context_len = uint(ctx_len);
// How many partitions were actually used for this sequence?
const uint total_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
const uint blocks_per_partition = PARTITION_SIZE / BLOCK_SIZE;
const uint num_parts = (total_blocks + blocks_per_partition - 1) / blocks_per_partition;
// If only 1 partition, the v2 kernel result is already final (just copy + normalize)
const uint pm_base = (seq_idx * num_heads_q + head_q) * max_partitions;
const uint po_base = ((seq_idx * num_heads_q + head_q) * max_partitions) * head_dim;
// Find global max across partitions
float m_global = -INFINITY;
for (uint p = 0; p < num_parts; ++p) {
float mp = partial_m[pm_base + p];
m_global = max(m_global, mp);
}
// Compute global l and accumulate output
float l_global = 0.0f;
const uint elems_per_lane = head_dim / SIMD_SIZE;
float o_final[HEAD_DIM_MAX / SIMD_SIZE];
for (uint i = 0; i < elems_per_lane; ++i) {
o_final[i] = 0.0f;
}
for (uint p = 0; p < num_parts; ++p) {
float mp = partial_m[pm_base + p];
float lp = partial_l[pm_base + p];
float w = exp(mp - m_global) * lp;
l_global += w;
// Weight this partition's output contribution
float scale_p = exp(mp - m_global);
for (uint i = 0; i < elems_per_lane; ++i) {
uint d = lane_id * elems_per_lane + i;
if (d < head_dim) {
o_final[i] += partial_out[po_base + p * head_dim + d] * scale_p;
}
}
}
// Normalize and store
const uint o_offset = seq_idx * num_heads_q * head_dim + head_q * head_dim;
float inv_l = (l_global > 0.0f) ? (1.0f / l_global) : 0.0f;
for (uint i = 0; i < elems_per_lane; ++i) {
uint d = lane_id * elems_per_lane + i;
if (d < head_dim) {
output[o_offset + d] = half(o_final[i] * inv_l);
}
}
}
// ---------------------------------------------------------------------------
// Paged Attention V1 - FP4 quantized KV cache
//
// KV blocks stored as packed FP4 E2M1 with per-token-per-head scales.
//
// K_cache_packed: [num_blocks, num_kv_heads, block_size, head_dim/8] (uint32)
// K_scales: [num_blocks, num_kv_heads, block_size] (half)
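//
// Indexing example (illustrative): for head_dim = 128, packed_head_dim = 16, so element
// d of a token lives in word d / 8 at nibble d % 8 (low nibble first); the scale is
// shared by the whole token, one half per (block, kv_head, token).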
// ---------------------------------------------------------------------------
kernel void paged_attention_v1_fp4(
device const half* Q [[buffer(0)]],
device const uint* K_cache_packed [[buffer(1)]],
device const uint* V_cache_packed [[buffer(2)]],
device const half* K_scales [[buffer(3)]],
device const half* V_scales [[buffer(4)]],
device const int* block_tables [[buffer(5)]],
device const int* context_lens [[buffer(6)]],
device half* output [[buffer(7)]],
constant uint& num_seqs [[buffer(8)]],
constant uint& num_heads_q [[buffer(9)]],
constant uint& num_kv_heads [[buffer(10)]],
constant uint& head_dim [[buffer(11)]],
constant uint& max_blocks_per_seq [[buffer(12)]],
constant float& scale [[buffer(13)]],
uint3 tgid [[threadgroup_position_in_grid]],
uint tid_in_tg [[thread_index_in_threadgroup]],
uint lane_id [[thread_index_in_simdgroup]],
uint sg_id [[simdgroup_index_in_threadgroup]]
) {
const uint seq_idx = tgid.x;
const uint head_q = tgid.y;
if (seq_idx >= num_seqs) return;
const uint gqa_ratio = num_heads_q / num_kv_heads;
const uint head_kv = head_q / gqa_ratio;
const int ctx_len = context_lens[seq_idx];
if (ctx_len <= 0) return;
const uint context_len = uint(ctx_len);
const uint num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
const uint q_offset = seq_idx * num_heads_q * head_dim + head_q * head_dim;
const uint elems_per_lane = head_dim / SIMD_SIZE;
float q_reg[HEAD_DIM_MAX / SIMD_SIZE];
if (sg_id == 0) {
for (uint i = 0; i < elems_per_lane; ++i) {
uint d = lane_id * elems_per_lane + i;
q_reg[i] = (d < head_dim) ? float(Q[q_offset + d]) : 0.0f;
}
}
// Packed dimensions
const uint packed_head_dim = head_dim / FP4_PER_UINT;
// KV packed layout: [num_blocks, num_kv_heads, block_size, packed_head_dim]
const uint kv_packed_head_stride = BLOCK_SIZE * packed_head_dim;
const uint kv_packed_block_stride = num_kv_heads * kv_packed_head_stride;
// Scale layout: [num_blocks, num_kv_heads, block_size]
const uint scale_head_stride = BLOCK_SIZE;
const uint scale_block_stride = num_kv_heads * scale_head_stride;
threadgroup half K_smem[KV_TILES][BLOCK_SIZE][HEAD_DIM_MAX];
threadgroup half V_smem[KV_TILES][BLOCK_SIZE][HEAD_DIM_MAX];
float m_prev = -INFINITY;
float l_prev = 0.0f;
float o_acc[HEAD_DIM_MAX / SIMD_SIZE];
for (uint i = 0; i < elems_per_lane; ++i) {
o_acc[i] = 0.0f;
}
device const int* seq_block_table = block_tables + seq_idx * max_blocks_per_seq;
// Preload first block (dequantize FP4 -> FP16 into threadgroup memory)
{
int phys_block = seq_block_table[0];
uint k_packed_base = uint(phys_block) * kv_packed_block_stride + head_kv * kv_packed_head_stride;
uint k_scale_base = uint(phys_block) * scale_block_stride + head_kv * scale_head_stride;
uint packed_elems = BLOCK_SIZE * packed_head_dim;
uint loads_per_thread = (packed_elems + THREADS_PER_TG - 1) / THREADS_PER_TG;
// Load and dequant K
for (uint i = 0; i < loads_per_thread; ++i) {
uint idx = tid_in_tg + i * THREADS_PER_TG;
if (idx < packed_elems) {
uint token_in_block = idx / packed_head_dim;
uint packed_col = idx % packed_head_dim;
half s = K_scales[k_scale_base + token_in_block];
uint word = K_cache_packed[k_packed_base + token_in_block * packed_head_dim + packed_col];
uint base_d = packed_col * FP4_PER_UINT;
for (uint j = 0; j < FP4_PER_UINT && base_d + j < head_dim; ++j) {
uint nibble = (word >> (j * 4)) & 0xFu;
K_smem[0][token_in_block][base_d + j] = dequant_fp4(nibble, s);
}
}
}
// Load and dequant V
uint v_packed_base = uint(phys_block) * kv_packed_block_stride + head_kv * kv_packed_head_stride;
uint v_scale_base = uint(phys_block) * scale_block_stride + head_kv * scale_head_stride;
for (uint i = 0; i < loads_per_thread; ++i) {
uint idx = tid_in_tg + i * THREADS_PER_TG;
if (idx < packed_elems) {
uint token_in_block = idx / packed_head_dim;
uint packed_col = idx % packed_head_dim;
half s = V_scales[v_scale_base + token_in_block];
uint word = V_cache_packed[v_packed_base + token_in_block * packed_head_dim + packed_col];
uint base_d = packed_col * FP4_PER_UINT;
for (uint j = 0; j < FP4_PER_UINT && base_d + j < head_dim; ++j) {
uint nibble = (word >> (j * 4)) & 0xFu;
V_smem[0][token_in_block][base_d + j] = dequant_fp4(nibble, s);
}
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
uint buf_compute = 0;
for (uint block_idx = 0; block_idx < num_blocks; ++block_idx) {
uint buf_load = 1 - buf_compute;
// Load next block with FP4 dequant
if (block_idx + 1 < num_blocks) {
int next_phys_block = seq_block_table[block_idx + 1];
uint k_packed_base = uint(next_phys_block) * kv_packed_block_stride + head_kv * kv_packed_head_stride;
uint k_scale_base = uint(next_phys_block) * scale_block_stride + head_kv * scale_head_stride;
uint v_packed_base = k_packed_base; // Same layout for V
uint v_scale_base = uint(next_phys_block) * scale_block_stride + head_kv * scale_head_stride;
uint packed_elems = BLOCK_SIZE * packed_head_dim;
uint loads_per_thread = (packed_elems + THREADS_PER_TG - 1) / THREADS_PER_TG;
for (uint i = 0; i < loads_per_thread; ++i) {
uint idx = tid_in_tg + i * THREADS_PER_TG;
if (idx < packed_elems) {
uint token_in_block = idx / packed_head_dim;
uint packed_col = idx % packed_head_dim;
half ks = K_scales[k_scale_base + token_in_block];
uint k_word = K_cache_packed[k_packed_base + token_in_block * packed_head_dim + packed_col];
half vs = V_scales[v_scale_base + token_in_block];
uint v_word = V_cache_packed[v_packed_base + token_in_block * packed_head_dim + packed_col];
uint base_d = packed_col * FP4_PER_UINT;
for (uint j = 0; j < FP4_PER_UINT && base_d + j < head_dim; ++j) {
uint k_nibble = (k_word >> (j * 4)) & 0xFu;
uint v_nibble = (v_word >> (j * 4)) & 0xFu;
K_smem[buf_load][token_in_block][base_d + j] = dequant_fp4(k_nibble, ks);
V_smem[buf_load][token_in_block][base_d + j] = dequant_fp4(v_nibble, vs);
}
}
}
}
uint block_start_token = block_idx * BLOCK_SIZE;
uint block_tokens = min(uint(BLOCK_SIZE), context_len - block_start_token);
if (sg_id == 0) {
float scores[BLOCK_SIZE];
for (uint t = 0; t < block_tokens; ++t) {
float dot = 0.0f;
for (uint i = 0; i < elems_per_lane; ++i) {
uint d = lane_id * elems_per_lane + i;
dot += q_reg[i] * float(K_smem[buf_compute][t][d]);
}
dot = simd_reduce_sum(dot);
scores[t] = dot * scale;
}
for (uint t = block_tokens; t < BLOCK_SIZE; ++t) {
scores[t] = -INFINITY;
}
float m_block = -INFINITY;
for (uint t = 0; t < block_tokens; ++t) {
m_block = max(m_block, scores[t]);
}
float m_new = max(m_prev, m_block);
float correction = exp(m_prev - m_new);
float l_new = l_prev * correction;
for (uint t = 0; t < block_tokens; ++t) {
l_new += exp(scores[t] - m_new);
}
for (uint i = 0; i < elems_per_lane; ++i) {
o_acc[i] *= correction;
}
for (uint t = 0; t < block_tokens; ++t) {
float p = exp(scores[t] - m_new);
for (uint i = 0; i < elems_per_lane; ++i) {
uint d = lane_id * elems_per_lane + i;
o_acc[i] += p * float(V_smem[buf_compute][t][d]);
}
}
m_prev = m_new;
l_prev = l_new;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
buf_compute = buf_load;
}
if (sg_id == 0) {
const uint o_offset = seq_idx * num_heads_q * head_dim + head_q * head_dim;
float inv_l = (l_prev > 0.0f) ? (1.0f / l_prev) : 0.0f;
for (uint i = 0; i < elems_per_lane; ++i) {
uint d = lane_id * elems_per_lane + i;
if (d < head_dim) {
output[o_offset + d] = half(o_acc[i] * inv_l);
}
}
}
}
// ---------------------------------------------------------------------------
// Paged Attention V1 - INT4 quantized KV cache
//
// Same structure as FP4 variant but uses signed 4-bit integer dequantization.
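//
// Assumed producer-side convention (not defined in this file): the quantizer stores
// q = clamp(round(x / scale), -8, 7) + 8 in each nibble, which dequant_int4 inverts as
// (nibble - 8) * scale.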
// ---------------------------------------------------------------------------
kernel void paged_attention_v1_int4(
device const half* Q [[buffer(0)]],
device const uint* K_cache_packed [[buffer(1)]],
device const uint* V_cache_packed [[buffer(2)]],
device const half* K_scales [[buffer(3)]],
device const half* V_scales [[buffer(4)]],
device const int* block_tables [[buffer(5)]],
device const int* context_lens [[buffer(6)]],
device half* output [[buffer(7)]],
constant uint& num_seqs [[buffer(8)]],
constant uint& num_heads_q [[buffer(9)]],
constant uint& num_kv_heads [[buffer(10)]],
constant uint& head_dim [[buffer(11)]],
constant uint& max_blocks_per_seq [[buffer(12)]],
constant float& scale [[buffer(13)]],
uint3 tgid [[threadgroup_position_in_grid]],
uint tid_in_tg [[thread_index_in_threadgroup]],
uint lane_id [[thread_index_in_simdgroup]],
uint sg_id [[simdgroup_index_in_threadgroup]]
) {
const uint seq_idx = tgid.x;
const uint head_q = tgid.y;
if (seq_idx >= num_seqs) return;
const uint gqa_ratio = num_heads_q / num_kv_heads;
const uint head_kv = head_q / gqa_ratio;
const int ctx_len = context_lens[seq_idx];
if (ctx_len <= 0) return;
const uint context_len = uint(ctx_len);
const uint num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
const uint q_offset = seq_idx * num_heads_q * head_dim + head_q * head_dim;
const uint elems_per_lane = head_dim / SIMD_SIZE;
float q_reg[HEAD_DIM_MAX / SIMD_SIZE];
if (sg_id == 0) {
for (uint i = 0; i < elems_per_lane; ++i) {
uint d = lane_id * elems_per_lane + i;
q_reg[i] = (d < head_dim) ? float(Q[q_offset + d]) : 0.0f;
}
}
const uint packed_head_dim = head_dim / FP4_PER_UINT;
const uint kv_packed_head_stride = BLOCK_SIZE * packed_head_dim;
const uint kv_packed_block_stride = num_kv_heads * kv_packed_head_stride;
const uint scale_head_stride = BLOCK_SIZE;
const uint scale_block_stride = num_kv_heads * scale_head_stride;
threadgroup half K_smem[KV_TILES][BLOCK_SIZE][HEAD_DIM_MAX];
threadgroup half V_smem[KV_TILES][BLOCK_SIZE][HEAD_DIM_MAX];
float m_prev = -INFINITY;
float l_prev = 0.0f;
float o_acc[HEAD_DIM_MAX / SIMD_SIZE];
for (uint i = 0; i < elems_per_lane; ++i) {
o_acc[i] = 0.0f;
}
device const int* seq_block_table = block_tables + seq_idx * max_blocks_per_seq;
// Preload first block with INT4 dequant
{
int phys_block = seq_block_table[0];
uint k_packed_base = uint(phys_block) * kv_packed_block_stride + head_kv * kv_packed_head_stride;
uint k_scale_base = uint(phys_block) * scale_block_stride + head_kv * scale_head_stride;
uint v_packed_base = k_packed_base;
uint v_scale_base = k_scale_base;
uint packed_elems = BLOCK_SIZE * packed_head_dim;
uint loads_per_thread = (packed_elems + THREADS_PER_TG - 1) / THREADS_PER_TG;
for (uint i = 0; i < loads_per_thread; ++i) {
uint idx = tid_in_tg + i * THREADS_PER_TG;
if (idx < packed_elems) {
uint token_in_block = idx / packed_head_dim;
uint packed_col = idx % packed_head_dim;
half ks = K_scales[k_scale_base + token_in_block];
uint k_word = K_cache_packed[k_packed_base + token_in_block * packed_head_dim + packed_col];
half vs = V_scales[v_scale_base + token_in_block];
uint v_word = V_cache_packed[v_packed_base + token_in_block * packed_head_dim + packed_col];
uint base_d = packed_col * FP4_PER_UINT;
for (uint j = 0; j < FP4_PER_UINT && base_d + j < head_dim; ++j) {
uint k_nibble = (k_word >> (j * 4)) & 0xFu;
uint v_nibble = (v_word >> (j * 4)) & 0xFu;
K_smem[0][token_in_block][base_d + j] = dequant_int4(k_nibble, ks);
V_smem[0][token_in_block][base_d + j] = dequant_int4(v_nibble, vs);
}
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
uint buf_compute = 0;
for (uint block_idx = 0; block_idx < num_blocks; ++block_idx) {
uint buf_load = 1 - buf_compute;
if (block_idx + 1 < num_blocks) {
int next_phys_block = seq_block_table[block_idx + 1];
uint k_packed_base = uint(next_phys_block) * kv_packed_block_stride + head_kv * kv_packed_head_stride;
uint k_scale_base = uint(next_phys_block) * scale_block_stride + head_kv * scale_head_stride;
uint v_packed_base = k_packed_base;
uint v_scale_base = k_scale_base;
uint packed_elems = BLOCK_SIZE * packed_head_dim;
uint loads_per_thread = (packed_elems + THREADS_PER_TG - 1) / THREADS_PER_TG;
for (uint i = 0; i < loads_per_thread; ++i) {
uint idx = tid_in_tg + i * THREADS_PER_TG;
if (idx < packed_elems) {
uint token_in_block = idx / packed_head_dim;
uint packed_col = idx % packed_head_dim;
half ks = K_scales[k_scale_base + token_in_block];
uint k_word = K_cache_packed[k_packed_base + token_in_block * packed_head_dim + packed_col];
half vs = V_scales[v_scale_base + token_in_block];
uint v_word = V_cache_packed[v_packed_base + token_in_block * packed_head_dim + packed_col];
uint base_d = packed_col * FP4_PER_UINT;
for (uint j = 0; j < FP4_PER_UINT && base_d + j < head_dim; ++j) {
uint k_nibble = (k_word >> (j * 4)) & 0xFu;
uint v_nibble = (v_word >> (j * 4)) & 0xFu;
K_smem[buf_load][token_in_block][base_d + j] = dequant_int4(k_nibble, ks);
V_smem[buf_load][token_in_block][base_d + j] = dequant_int4(v_nibble, vs);
}
}
}
}
uint block_start_token = block_idx * BLOCK_SIZE;
uint block_tokens = min(uint(BLOCK_SIZE), context_len - block_start_token);
if (sg_id == 0) {
float scores[BLOCK_SIZE];
for (uint t = 0; t < block_tokens; ++t) {
float dot = 0.0f;
for (uint i = 0; i < elems_per_lane; ++i) {
uint d = lane_id * elems_per_lane + i;
dot += q_reg[i] * float(K_smem[buf_compute][t][d]);
}
dot = simd_reduce_sum(dot);
scores[t] = dot * scale;
}
for (uint t = block_tokens; t < BLOCK_SIZE; ++t) {
scores[t] = -INFINITY;
}
float m_block = -INFINITY;
for (uint t = 0; t < block_tokens; ++t) {
m_block = max(m_block, scores[t]);
}
float m_new = max(m_prev, m_block);
float correction = exp(m_prev - m_new);
float l_new = l_prev * correction;
for (uint t = 0; t < block_tokens; ++t) {
l_new += exp(scores[t] - m_new);
}
for (uint i = 0; i < elems_per_lane; ++i) {
o_acc[i] *= correction;
}
for (uint t = 0; t < block_tokens; ++t) {
float p = exp(scores[t] - m_new);
for (uint i = 0; i < elems_per_lane; ++i) {
uint d = lane_id * elems_per_lane + i;
o_acc[i] += p * float(V_smem[buf_compute][t][d]);
}
}
m_prev = m_new;
l_prev = l_new;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
buf_compute = buf_load;
}
if (sg_id == 0) {
const uint o_offset = seq_idx * num_heads_q * head_dim + head_q * head_dim;
float inv_l = (l_prev > 0.0f) ? (1.0f / l_prev) : 0.0f;
for (uint i = 0; i < elems_per_lane; ++i) {
uint d = lane_id * elems_per_lane + i;
if (d < head_dim) {
output[o_offset + d] = half(o_acc[i] * inv_l);
}
}
}
}