A mask-and-scan radix sort implementation, as explained to me by Claude (Opus 4.5).

[1] https://web.archive.org/web/20221012085306/https://vgc.poly.edu/~csilva/papers/cgf.pdf
[2] https://github.com/mark-poscablo/gpu-radix-sort
"""
Pseudo-CUDA implementation of radix sort local block sort.

This shows how the ACTUAL GPU algorithm works at the thread level,
with masks and parallel prefix scans - no sequential counting!

Key insight: Instead of a sequential loop with `counter += 1`,
each thread discovers its position via parallel operations.
"""
from typing import List, Tuple
from dataclasses import dataclass


# =============================================================================
# Simulated CUDA Environment
# =============================================================================

@dataclass
class ThreadContext:
    """Simulates CUDA thread-local state."""
    thread_id: int
    block_size: int
    # Thread-private variables (each thread has its own copy)
    my_value: int = 0
    my_digit: int = 0
    my_position: int = 0  # Position within my digit group


class SharedMemory:
    """Simulates CUDA shared memory (visible to all threads in block)."""

    def __init__(self, size: int):
        self.data = [0] * size

    def __getitem__(self, idx: int) -> int:
        return self.data[idx]

    def __setitem__(self, idx: int, value: int):
        self.data[idx] = value


class BlockSimulator:
    """
    Simulates a CUDA block with multiple threads.

    In real CUDA:
    - All threads execute the SAME code simultaneously
    - __syncthreads() makes all threads wait for each other
    - Shared memory is visible to all threads in the block
    """

    def __init__(self, block_size: int):
        self.block_size = block_size
        self.threads = [ThreadContext(i, block_size) for i in range(block_size)]
        # Shared memory arrays
        self.smem_mask = SharedMemory(block_size)
        self.smem_scan = SharedMemory(block_size + 1)  # +1 for exclusive scan
        self.smem_digit_offsets = SharedMemory(4)
        self.smem_digit_counts = SharedMemory(4)
        self.smem_output = SharedMemory(block_size)
        self.smem_output_prefix = SharedMemory(block_size)

    def barrier(self):
        """
        Simulates __syncthreads().

        In real CUDA, this makes all threads wait until everyone reaches this point.
        In our simulation, we just note where barriers happen.
        """
        pass  # In simulation, we execute threads in lockstep anyway

    def run_kernel(self, kernel_func, *args):
        """
        Run a kernel function for all threads.

        In real CUDA: all threads run simultaneously.
        In simulation: we run them one at a time, but the algorithm
        is designed so order doesn't matter between barriers.
        """
        for thread in self.threads:
            kernel_func(thread, self, *args)


# =============================================================================
# Parallel Primitives
# =============================================================================

def parallel_exclusive_scan_simulation(block: BlockSimulator, input_smem: SharedMemory,
                                       output_smem: SharedMemory, n: int):
    """
    Simulates a parallel exclusive prefix scan.

    In real CUDA, this would be done with a tree-based algorithm
    where all threads participate simultaneously.

    Input:  [0, 0, 0, 1, 0, 0, 0, 1]
    Output: [0, 0, 0, 0, 1, 1, 1, 1]

    Each thread reads from input, writes to output.
    """
    # For simulation, just do it sequentially
    # (In real GPU, this is O(log n) parallel steps)
    running_sum = 0
    for i in range(n):
        output_smem[i] = running_sum
        running_sum += input_smem[i]
    output_smem[n] = running_sum  # Total sum at the end
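

# The O(log n) claim above can be made concrete with a Hillis-Steele style scan:
# every element adds in the element `stride` positions to its left, and `stride`
# doubles each round, so n elements need only ceil(log2(n)) rounds with all
# threads active in every round. This helper is an illustrative sketch only and
# is not used by the simulation below.
def hillis_steele_exclusive_scan_sketch(values: List[int]) -> List[int]:
    """Exclusive prefix scan in ceil(log2(n)) simulated parallel rounds."""
    n = len(values)
    scan = list(values)  # will become the inclusive scan
    stride = 1
    while stride < n:
        # In CUDA every thread reads from the OLD buffer; a __syncthreads()
        # separates the read phase from the write phase.
        previous = list(scan)
        for tid in range(n):  # conceptually: all threads at once
            if tid >= stride:
                scan[tid] = previous[tid] + previous[tid - stride]
        stride *= 2
    # Inclusive -> exclusive: shift right by one and put 0 in front.
    return [0] + scan[:-1]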
# =============================================================================
# The Actual CUDA Kernel (Pseudo-code as Python)
# =============================================================================

def extract_digit(value: int, shift: int) -> int:
    return (value >> shift) & 0b11


def kernel_phase1_load(thread: ThreadContext, block: BlockSimulator,
                       input_data: List[int], shift_width: int):
    """
    Step 1: Each thread loads its element and extracts its digit.

    CUDA equivalent:
        __global__ void load_kernel(uint* input, int shift) {
            int tid = threadIdx.x;
            my_value = input[tid];
            my_digit = (my_value >> shift) & 0x3;
        }
    """
    tid = thread.thread_id
    thread.my_value = input_data[tid]
    thread.my_digit = extract_digit(thread.my_value, shift_width)
def kernel_phase2_build_mask_and_scan(thread: ThreadContext, block: BlockSimulator,
                                      target_digit: int):
    """
    Step 2: Build mask for one digit and scan it.

    This is the KEY parallel operation that replaces sequential counting!
    (In this simulation the kernel only builds the mask; the scan itself is
    done by parallel_exclusive_scan_simulation, called from the driver loop
    in cuda_local_block_sort.)

    CUDA equivalent:
        __global__ void build_mask_and_scan(int target_digit) {
            int tid = threadIdx.x;
            // Each thread: "Do I have this digit?"
            smem_mask[tid] = (my_digit == target_digit) ? 1 : 0;
            __syncthreads();
            // Parallel prefix scan (in practice, use CUB or a custom implementation)
            smem_scan[tid] = parallel_exclusive_scan(smem_mask)[tid];
            __syncthreads();
        }
    """
    tid = thread.thread_id
    # Each thread writes to its own slot - NO CONFLICTS
    if thread.my_digit == target_digit:
        block.smem_mask[tid] = 1
    else:
        block.smem_mask[tid] = 0
def kernel_phase3_save_position(thread: ThreadContext, block: BlockSimulator,
                                target_digit: int):
    """
    Step 3: Each thread with the target digit reads its position from the scan.

    CUDA equivalent:
        __global__ void save_position(int target_digit) {
            int tid = threadIdx.x;
            if (my_digit == target_digit) {
                my_position = smem_scan[tid];
            }
        }
    """
    tid = thread.thread_id
    if thread.my_digit == target_digit:
        # The scan result tells me: "How many elements with this digit are before me?"
        thread.my_position = block.smem_scan[tid]
def kernel_phase4_compute_digit_counts(thread: ThreadContext, block: BlockSimulator):
    """
    Step 4: Compute how many of each digit we have, and where each group starts.

    Only 4 threads are needed for this (or thread 0 does it all).
    In this simulation the driver code in cuda_local_block_sort computes the
    counts from the scan totals and writes smem_digit_offsets directly, so
    this kernel is illustrative and is never launched.
    """
    tid = thread.thread_id
    if tid < 4:
        # Offset of digit d = total count of all smaller digits (exclusive scan)
        offset = 0
        for d in range(tid):
            offset += block.smem_digit_counts[d]
        block.smem_digit_offsets[tid] = offset
def kernel_phase5_scatter(thread: ThreadContext, block: BlockSimulator):
    """
    Step 5: Each thread writes its element to the correct sorted position.

    CUDA equivalent:
        __global__ void scatter() {
            int tid = threadIdx.x;
            int new_pos = digit_offsets[my_digit] + my_position;
            output[new_pos] = my_value;
            output_prefix[new_pos] = my_position;
        }
    """
    tid = thread.thread_id
    digit = thread.my_digit
    digit_offset = block.smem_digit_offsets[digit]
    new_pos = digit_offset + thread.my_position
    # Each thread writes to a DIFFERENT position - NO CONFLICTS
    block.smem_output[new_pos] = thread.my_value
    block.smem_output_prefix[new_pos] = thread.my_position
# =============================================================================
# Complete Local Sort Simulation
# =============================================================================

def cuda_local_block_sort(input_data: List[int], shift_width: int) -> Tuple[List[int], List[int], List[int]]:
    """
    Simulate the full CUDA local block sort.
    Shows step-by-step what happens in a real GPU.
    """
    n = len(input_data)
    block = BlockSimulator(n)

    print(f"\n{'='*70}")
    print(f"CUDA LOCAL BLOCK SORT SIMULATION")
    print(f"Input: {input_data}")
    print(f"Shift: {shift_width} (looking at bits {shift_width}-{shift_width+1})")
    print(f"{'='*70}")

    # =========================================================================
    # STEP 1: Load data (all threads in parallel)
    # =========================================================================
    print(f"\n[STEP 1] All {n} threads load their element and extract digit")
    block.run_kernel(kernel_phase1_load, input_data, shift_width)

    print(f"  Thread ID: {[t.thread_id for t in block.threads]}")
    print(f"  my_value:  {[t.my_value for t in block.threads]}")
    print(f"  my_digit:  {[t.my_digit for t in block.threads]}")

    # =========================================================================
    # STEP 2-3: For each digit, build mask and scan (this is the key!)
    # =========================================================================
    digit_counts = [0, 0, 0, 0]

    for target_digit in range(4):
        print(f"\n[STEP 2-3] Processing digit {target_digit}")

        # All threads build mask in parallel
        print(f"  All threads: 'Do I have digit {target_digit}?'")
        block.run_kernel(kernel_phase2_build_mask_and_scan, target_digit)
        mask = [block.smem_mask[i] for i in range(n)]
        print(f"  Mask:        {mask}")

        # Barrier - wait for all threads
        block.barrier()

        # Parallel prefix scan (simulated)
        parallel_exclusive_scan_simulation(block, block.smem_mask, block.smem_scan, n)
        scan_result = [block.smem_scan[i] for i in range(n)]
        total = block.smem_scan[n]
        print(f"  Scan result: {scan_result} (total: {total})")
        digit_counts[target_digit] = total

        # Barrier
        block.barrier()

        # Each thread with this digit saves its position
        block.run_kernel(kernel_phase3_save_position, target_digit)
        positions = [t.my_position if t.my_digit == target_digit else '-'
                     for t in block.threads]
        print(f"  Positions:   {positions}")

    print(f"\n[STEP 4] Compute digit offsets")
    print(f"  Digit counts: {digit_counts}")

    # Exclusive scan of digit counts to get offsets
    digit_offsets = [0, 0, 0, 0]
    running = 0
    for d in range(4):
        digit_offsets[d] = running
        block.smem_digit_offsets[d] = running
        running += digit_counts[d]
    print(f"  Digit offsets: {digit_offsets}")

    # =========================================================================
    # STEP 5: Scatter (all threads in parallel)
    # =========================================================================
    print(f"\n[STEP 5] All threads scatter to sorted positions")
    block.run_kernel(kernel_phase5_scatter)

    for t in block.threads:
        new_pos = digit_offsets[t.my_digit] + t.my_position
        print(f"  Thread {t.thread_id}: value={t.my_value}, digit={t.my_digit}, "
              f"offset={digit_offsets[t.my_digit]}, position={t.my_position} "
              f"→ output[{new_pos}]")

    # =========================================================================
    # Extract results
    # =========================================================================
    sorted_output = [block.smem_output[i] for i in range(n)]
    sorted_prefix = [block.smem_output_prefix[i] for i in range(n)]

    print(f"\n[RESULT]")
    print(f"  Sorted output:  {sorted_output}")
    print(f"  Local prefixes: {sorted_prefix}")
    print(f"  Digit counts:   {digit_counts}")

    return sorted_output, sorted_prefix, digit_counts
# =============================================================================
# Why This Is Parallel (No Sequential Counting!)
# =============================================================================

def explain_parallelism():
    """
    Explain why the mask-and-scan approach is truly parallel.
    """
    print("""
╔══════════════════════════════════════════════════════════════════════════════╗
║                         WHY MASK-AND-SCAN IS PARALLEL                         ║
╠══════════════════════════════════════════════════════════════════════════════╣
║                                                                                ║
║  SEQUENTIAL APPROACH (what we had before):                                     ║
║  ─────────────────────────────────────────                                     ║
║                                                                                ║
║      counter = 0                                                               ║
║      for i in range(n):          # Must be sequential!                         ║
║          position[i] = counter   # Read counter                                ║
║          counter += 1            # Modify counter                              ║
║                                  # Next iteration depends on this!             ║
║                                                                                ║
║  Problem: Iteration i+1 needs the result of iteration i.                       ║
║                                                                                ║
║  ────────────────────────────────────────────────────────────────────────     ║
║                                                                                ║
║  PARALLEL APPROACH (mask and scan):                                            ║
║  ──────────────────────────────────                                            ║
║                                                                                ║
║  Step 1: Build mask (ALL THREADS SIMULTANEOUSLY)                               ║
║    ┌────────────────────────────────────────────────────┐                      ║
║    │ Thread 0: mask[0] = (my_digit == target) ? 1 : 0   │                      ║
║    │ Thread 1: mask[1] = (my_digit == target) ? 1 : 0   │  PARALLEL            ║
║    │ Thread 2: mask[2] = (my_digit == target) ? 1 : 0   │  (no deps)           ║
║    │ ...                                                │                      ║
║    └────────────────────────────────────────────────────┘                      ║
║                                                                                ║
║  Step 2: Parallel prefix scan (O(log n) steps, all threads active)             ║
║    ┌────────────────────────────────────────────────────┐                      ║
║    │ Input:  [0, 0, 0, 1, 0, 0, 0, 1]                   │                      ║
║    │            ↓  (log₂(8) = 3 parallel steps)         │                      ║
║    │ Output: [0, 0, 0, 0, 1, 1, 1, 1]                   │                      ║
║    └────────────────────────────────────────────────────┘                      ║
║                                                                                ║
║  Step 3: Read position (ALL THREADS SIMULTANEOUSLY)                            ║
║    ┌────────────────────────────────────────────────────┐                      ║
║    │ Thread 3: my_position = scan[3] = 0   "I'm 1st!"   │                      ║
║    │ Thread 7: my_position = scan[7] = 1   "I'm 2nd!"   │  PARALLEL            ║
║    │ (Other threads don't have this digit, they skip)   │  (no deps)           ║
║    └────────────────────────────────────────────────────┘                      ║
║                                                                                ║
║  No thread depends on another thread's result!                                 ║
║  Each thread discovers its position INDEPENDENTLY.                             ║
║                                                                                ║
╚══════════════════════════════════════════════════════════════════════════════╝
""")
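

# A tiny numeric check of the claim in the box above: the positions produced by
# the sequential `counter += 1` loop and by mask + exclusive scan are the same
# numbers, which is why the shared counter can be eliminated entirely. This
# helper is illustrative only and is not used by the simulation.
def compare_sequential_vs_scan(digits: List[int], target_digit: int):
    # Sequential: walk the array and hand out positions one at a time.
    sequential_positions = {}
    counter = 0
    for i, d in enumerate(digits):
        if d == target_digit:
            sequential_positions[i] = counter
            counter += 1

    # Mask-and-scan: build the mask, exclusive-scan it, read your own slot.
    mask = [1 if d == target_digit else 0 for d in digits]
    scan = [0] * len(digits)
    running = 0
    for i, m in enumerate(mask):
        scan[i] = running
        running += m
    scan_positions = {i: scan[i] for i, d in enumerate(digits) if d == target_digit}

    assert sequential_positions == scan_positions
    print(f"  digit {target_digit}: positions {scan_positions} (both methods agree)")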
# =============================================================================
# Test and Demo
# =============================================================================

def test_cuda_sort():
    """Test the CUDA-style sort matches expected output."""
    input_data = [7, 2, 5, 0, 3, 6, 1, 4]
    sorted_out, prefix_out, counts = cuda_local_block_sort(input_data, shift_width=0)

    expected_sorted = [0, 4, 5, 1, 2, 6, 7, 3]
    expected_counts = [2, 2, 2, 2]

    assert sorted_out == expected_sorted, f"Expected {expected_sorted}, got {sorted_out}"
    assert counts == expected_counts, f"Expected {expected_counts}, got {counts}"
    print("\n✓ Test passed!")


if __name__ == "__main__":
    explain_parallelism()
    test_cuda_sort()
"""
Full Pseudo-CUDA implementation of radix sort with MULTIPLE BLOCKS.

Shows all three phases:
1. Local block sort (each block sorts its portion)
2. Scan block sums (compute global offsets)
3. Global shuffle (scatter to final positions)

Using 16 elements split into 2 blocks of 8.
"""
from typing import List, Tuple, Dict
from dataclasses import dataclass, field


# =============================================================================
# Simulated GPU Memory
# =============================================================================

class GlobalMemory:
    """
    Simulates GPU global memory - accessible by ALL blocks.
    This is like cudaMalloc'd memory.
    """

    def __init__(self):
        self.buffers: Dict[str, List[int]] = {}

    def allocate(self, name: str, size: int, init_value: int = 0):
        self.buffers[name] = [init_value] * size

    def read(self, name: str, idx: int) -> int:
        return self.buffers[name][idx]

    def write(self, name: str, idx: int, value: int):
        self.buffers[name][idx] = value

    def get_buffer(self, name: str) -> List[int]:
        return self.buffers[name]


class SharedMemory:
    """Simulates CUDA shared memory (only visible within one block)."""

    def __init__(self, size: int):
        self.data = [0] * size

    def __getitem__(self, idx: int) -> int:
        return self.data[idx]

    def __setitem__(self, idx: int, value: int):
        self.data[idx] = value


@dataclass
class ThreadContext:
    """Per-thread state."""
    thread_id: int   # threadIdx.x (local to block)
    block_id: int    # blockIdx.x
    global_id: int   # blockIdx.x * blockDim.x + threadIdx.x
    # Thread-private variables
    my_value: int = 0
    my_digit: int = 0
    my_local_position: int = 0   # Position within digit group IN THIS BLOCK
    my_global_position: int = 0  # Final position in output array


# =============================================================================
# Helper Functions
# =============================================================================

def extract_digit(value: int, shift: int) -> int:
    return (value >> shift) & 0b11


def exclusive_scan(arr: List[int]) -> List[int]:
    """Exclusive prefix sum."""
    result = [0] * len(arr)
    running = 0
    for i in range(len(arr)):
        result[i] = running
        running += arr[i]
    return result
# =============================================================================
# PHASE 1: Local Block Sort (runs on each block independently)
# =============================================================================

def phase1_local_sort(
    block_id: int,
    block_size: int,
    global_mem: GlobalMemory,
    shift_width: int
) -> Tuple[List[int], List[int], List[int]]:
    """
    Each block sorts its portion of the input.

    Returns:
        sorted_data:  Locally sorted elements
        local_prefix: Each element's position within its digit group
        digit_counts: [count_0, count_1, count_2, count_3] for this block
    """
    print(f"\n  [Block {block_id}] Local sort")

    # Create threads for this block
    threads = [
        ThreadContext(
            thread_id=t,
            block_id=block_id,
            global_id=block_id * block_size + t
        )
        for t in range(block_size)
    ]

    # Shared memory for this block
    smem_mask = [0] * block_size
    smem_scan = [0] * (block_size + 1)

    # Step 1: Each thread loads its element
    input_buffer = global_mem.get_buffer("input")
    for t in threads:
        t.my_value = input_buffer[t.global_id]
        t.my_digit = extract_digit(t.my_value, shift_width)

    print(f"    Thread IDs:   {[t.thread_id for t in threads]}")
    print(f"    Values:       {[t.my_value for t in threads]}")
    print(f"    Digits:       {[t.my_digit for t in threads]}")

    # Step 2-3: Mask and scan for each digit
    digit_counts = [0, 0, 0, 0]
    for target_digit in range(4):
        # Build mask (parallel)
        for t in threads:
            smem_mask[t.thread_id] = 1 if t.my_digit == target_digit else 0

        # Exclusive scan (parallel in real CUDA)
        running = 0
        for i in range(block_size):
            smem_scan[i] = running
            running += smem_mask[i]
        smem_scan[block_size] = running  # Total
        digit_counts[target_digit] = running

        # Each thread with this digit saves its position
        for t in threads:
            if t.my_digit == target_digit:
                t.my_local_position = smem_scan[t.thread_id]

    print(f"    Digit counts: {digit_counts}")

    # Step 4: Compute digit offsets within block
    digit_offsets = exclusive_scan(digit_counts)

    # Step 5: Scatter to sorted positions (within block)
    sorted_data = [0] * block_size
    local_prefix = [0] * block_size
    for t in threads:
        new_pos = digit_offsets[t.my_digit] + t.my_local_position
        sorted_data[new_pos] = t.my_value
        local_prefix[new_pos] = t.my_local_position

    print(f"    Sorted:       {sorted_data}")
    print(f"    Local prefix: {local_prefix}")

    return sorted_data, local_prefix, digit_counts
# =============================================================================
# PHASE 2: Scan Block Sums (compute global offsets)
# =============================================================================

def phase2_scan_block_sums(
    all_block_counts: List[List[int]],  # [block][digit]
    num_blocks: int
) -> List[List[int]]:
    """
    Compute global starting position for each (digit, block) pair.

    In CUDA, this could be done by:
    - A separate kernel launch
    - Or integrated with Phase 3

    The key insight: we need to know where EACH block's elements
    for EACH digit should go in the final output.

    Returns:
        global_offsets[block][digit] = starting position in output
    """
    print(f"\n{'='*70}")
    print("PHASE 2: Scan Block Sums")
    print(f"{'='*70}")

    # Flatten: all digit-0 counts, then all digit-1 counts, etc.
    # This is how the CUDA algorithm typically lays it out:
    #   [block0_d0, block1_d0, block0_d1, block1_d1, ...]
    print("\n  Block counts per digit:")
    for d in range(4):
        counts = [all_block_counts[b][d] for b in range(num_blocks)]
        print(f"    Digit {d}: {counts}")

    # Compute global offsets
    # Order: all digit 0s first, then digit 1s, then 2s, then 3s
    global_offsets = [[0] * 4 for _ in range(num_blocks)]
    running = 0
    for digit in range(4):
        for block in range(num_blocks):
            global_offsets[block][digit] = running
            running += all_block_counts[block][digit]

    print("\n  Global offsets (where each block's digit group starts):")
    for d in range(4):
        offsets = [global_offsets[b][d] for b in range(num_blocks)]
        print(f"    Digit {d}: blocks start at {offsets}")

    return global_offsets
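

# The double loop above is equivalent to flattening the counts digit-major,
# exactly in the [block0_d0, block1_d0, block0_d1, block1_d1, ...] order the
# comment mentions, and running ONE exclusive scan over the flat array - which
# is how a GPU implementation would typically do it (a single scan kernel).
# This alternative is a sketch for illustration only; it is not called below.
def phase2_scan_block_sums_flat(
    all_block_counts: List[List[int]],  # [block][digit]
    num_blocks: int
) -> List[List[int]]:
    # Flatten digit-major: all digit-0 counts, then all digit-1 counts, ...
    flat_counts = [all_block_counts[b][d]
                   for d in range(4)
                   for b in range(num_blocks)]
    flat_offsets = exclusive_scan(flat_counts)
    # Un-flatten back into global_offsets[block][digit]
    global_offsets = [[0] * 4 for _ in range(num_blocks)]
    for d in range(4):
        for b in range(num_blocks):
            global_offsets[b][d] = flat_offsets[d * num_blocks + b]
    return global_offsets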
# =============================================================================
# PHASE 3: Global Shuffle (scatter to final positions)
# =============================================================================

def phase3_global_shuffle(
    block_id: int,
    block_size: int,
    sorted_data: List[int],
    local_prefix: List[int],
    global_offsets: List[List[int]],
    global_mem: GlobalMemory,
    shift_width: int
):
    """
    Each block scatters its elements to final global positions.

        global_position = global_offsets[block][digit] + local_prefix

    In CUDA, ALL blocks run this in parallel.
    """
    print(f"\n  [Block {block_id}] Global shuffle")

    for i in range(block_size):
        value = sorted_data[i]
        digit = extract_digit(value, shift_width)
        local_pos = local_prefix[i]
        global_pos = global_offsets[block_id][digit] + local_pos

        global_mem.write("output", global_pos, value)

        print(f"    Element {i}: value={value}, digit={digit}, "
              f"global_offset={global_offsets[block_id][digit]}, "
              f"local_pos={local_pos} → output[{global_pos}]")
# =============================================================================
# Complete Multi-Block Radix Sort Pass
# =============================================================================

def radix_sort_pass_multiblock(
    input_data: List[int],
    block_size: int,
    shift_width: int
) -> List[int]:
    """
    One complete pass of radix sort across multiple blocks.
    """
    n = len(input_data)
    num_blocks = (n + block_size - 1) // block_size

    # Pad input to a multiple of block_size
    padded_input = input_data + [0xFFFFFFFF] * (num_blocks * block_size - n)

    # Initialize global memory
    global_mem = GlobalMemory()
    global_mem.allocate("input", len(padded_input))
    global_mem.allocate("output", len(padded_input))
    for i, v in enumerate(padded_input):
        global_mem.write("input", i, v)

    print(f"\n{'#'*70}")
    print(f"# RADIX SORT PASS (shift={shift_width}, looking at bits {shift_width}-{shift_width+1})")
    print(f"# Input: {input_data}")
    print(f"# {num_blocks} blocks of {block_size} elements each")
    print(f"{'#'*70}")

    # =========================================================================
    # PHASE 1: Local Block Sort (all blocks run in parallel on GPU)
    # =========================================================================
    print(f"\n{'='*70}")
    print("PHASE 1: Local Block Sort (each block independently)")
    print(f"{'='*70}")

    all_sorted_blocks = []
    all_local_prefixes = []
    all_block_counts = []

    for block_id in range(num_blocks):
        sorted_data, local_prefix, digit_counts = phase1_local_sort(
            block_id, block_size, global_mem, shift_width
        )
        all_sorted_blocks.append(sorted_data)
        all_local_prefixes.append(local_prefix)
        all_block_counts.append(digit_counts)

    # =========================================================================
    # PHASE 2: Scan Block Sums
    # =========================================================================
    global_offsets = phase2_scan_block_sums(all_block_counts, num_blocks)

    # =========================================================================
    # PHASE 3: Global Shuffle (all blocks run in parallel on GPU)
    # =========================================================================
    print(f"\n{'='*70}")
    print("PHASE 3: Global Shuffle (each block independently)")
    print(f"{'='*70}")

    for block_id in range(num_blocks):
        phase3_global_shuffle(
            block_id, block_size,
            all_sorted_blocks[block_id],
            all_local_prefixes[block_id],
            global_offsets,
            global_mem,
            shift_width
        )

    # =========================================================================
    # Result
    # =========================================================================
    output = global_mem.get_buffer("output")[:n]

    print(f"\n{'='*70}")
    print("RESULT")
    print(f"{'='*70}")
    print(f"  Output: {output}")

    return output
# =============================================================================
# Visualization of the Full Algorithm
# =============================================================================

def visualize_algorithm():
    """ASCII art explanation of the multi-block algorithm."""
    print("""
╔══════════════════════════════════════════════════════════════════════════════╗
║                        MULTI-BLOCK RADIX SORT OVERVIEW                        ║
╠══════════════════════════════════════════════════════════════════════════════╣
║                                                                                ║
║  INPUT: 16 elements, 2 blocks of 8                                             ║
║  ┌──────────────────────────────┬────────────────────────────────────┐        ║
║  │ Block 0 (elements 0-7)       │ Block 1 (elements 8-15)            │        ║
║  │ [7, 2, 5, 0, 3, 6, 1, 4]     │ [10, 9, 8, 11, 14, 13, 12, 15]     │        ║
║  └──────────────────────────────┴────────────────────────────────────┘        ║
║                 │                                  │                           ║
║                 ▼                                  ▼                           ║
║  ═══════════════════════════════════════════════════════════════════════════  ║
║   PHASE 1: Local Sort (blocks work independently, IN PARALLEL)                 ║
║  ═══════════════════════════════════════════════════════════════════════════  ║
║                 │                                  │                           ║
║                 ▼                                  ▼                           ║
║  ┌──────────────────────────────┬────────────────────────────────────┐        ║
║  │ Sorted: [0,4, 5,1, 2,6, 7,3] │ Sorted: [8,12, 9,13, 10,14, 11,15] │        ║
║  │ Counts: [2,2,2,2]            │ Counts: [2,2,2,2]                  │        ║
║  │          d0 d1 d2 d3         │          d0 d1 d2 d3               │        ║
║  └──────────────────────────────┴────────────────────────────────────┘        ║
║                 │                                  │                           ║
║                 └────────────────┬─────────────────┘                           ║
║                                  ▼                                             ║
║  ═══════════════════════════════════════════════════════════════════════════  ║
║   PHASE 2: Scan Block Sums (compute global offsets)                            ║
║  ═══════════════════════════════════════════════════════════════════════════  ║
║                                                                                ║
║     Digit 0: Block0 has 2, Block1 has 2                                        ║
║     Digit 1: Block0 has 2, Block1 has 2                                        ║
║     Digit 2: Block0 has 2, Block1 has 2                                        ║
║     Digit 3: Block0 has 2, Block1 has 2                                        ║
║                                                                                ║
║     Global offsets (where does each block's digit group go?):                  ║
║     ┌────────┬─────────┬─────────┐                                             ║
║     │ Digit  │ Block 0 │ Block 1 │                                             ║
║     ├────────┼─────────┼─────────┤                                             ║
║     │   0    │    0    │    2    │  ← Digit 0: positions 0-3                   ║
║     │   1    │    4    │    6    │  ← Digit 1: positions 4-7                   ║
║     │   2    │    8    │   10    │  ← Digit 2: positions 8-11                  ║
║     │   3    │   12    │   14    │  ← Digit 3: positions 12-15                 ║
║     └────────┴─────────┴─────────┘                                             ║
║                                                                                ║
║                                  ▼                                             ║
║  ═══════════════════════════════════════════════════════════════════════════  ║
║   PHASE 3: Global Shuffle (blocks work independently, IN PARALLEL)             ║
║  ═══════════════════════════════════════════════════════════════════════════  ║
║                                                                                ║
║   Block 0: value=0  (digit 0) → offset[0][0]=0 + local_pos=0 → output[0]       ║
║   Block 0: value=4  (digit 0) → offset[0][0]=0 + local_pos=1 → output[1]       ║
║   Block 1: value=8  (digit 0) → offset[1][0]=2 + local_pos=0 → output[2]       ║
║   Block 1: value=12 (digit 0) → offset[1][0]=2 + local_pos=1 → output[3]       ║
║   ... and so on for digits 1, 2, 3 ...                                         ║
║                                                                                ║
║                                  ▼                                             ║
║   ┌────────────────────────────────────────────────────────────┐              ║
║   │ OUTPUT (sorted by 2-bit digit):                            │               ║
║   │  [0, 4, 8, 12, 5, 1, 9, 13, 2, 6, 10, 14, 7, 3, 11, 15]    │               ║
║   │  └─ digit 0 ─┘└─ digit 1 ─┘└─ digit 2 ──┘└─ digit 3 ──┘    │               ║
║   └────────────────────────────────────────────────────────────┘              ║
║                                                                                ║
║   After 16 passes (2 bits × 16 = 32 bits), array is fully sorted!              ║
║                                                                                ║
╚══════════════════════════════════════════════════════════════════════════════╝
""")
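

# The diagram above ends with "after 16 passes the array is fully sorted".
# A minimal driver that does exactly that is sketched below: it repeats
# radix_sort_pass_multiblock for every 2-bit digit of a 32-bit key, feeding each
# pass's output into the next (a real GPU version would ping-pong two device
# buffers instead of reallocating). The name full_radix_sort_multiblock is ours,
# purely for illustration; note it prints the full trace of every pass.
def full_radix_sort_multiblock(data: List[int], block_size: int = 8) -> List[int]:
    """Sketch: run all 16 passes (2 bits each) over non-negative 32-bit keys."""
    result = list(data)
    for shift_width in range(0, 32, 2):  # bits 0-1, 2-3, ..., 30-31
        result = radix_sort_pass_multiblock(result, block_size, shift_width)
    return result
# e.g. full_radix_sort_multiblock([7, 2, 5, 0, 3, 6, 1, 4, 10, 9, 8, 11, 14, 13,
# 12, 15]) returns the fully sorted list.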
# =============================================================================
# Main Demo
# =============================================================================

def main():
    visualize_algorithm()

    # 16 elements, 2 blocks of 8
    input_data = [7, 2, 5, 0, 3, 6, 1, 4, 10, 9, 8, 11, 14, 13, 12, 15]

    print("\n" + "="*70)
    print("RUNNING SIMULATION WITH ACTUAL DATA")
    print("="*70)

    result = radix_sort_pass_multiblock(
        input_data=input_data,
        block_size=8,
        shift_width=0  # Look at bits 0-1
    )

    # Verify: after one pass, elements should be grouped by their 2-bit digit
    print(f"\n{'='*70}")
    print("VERIFICATION")
    print(f"{'='*70}")
    digits = [extract_digit(v, 0) for v in result]
    print(f"  Result digits: {digits}")
    print(f"  Should be:     [0,0,0,0, 1,1,1,1, 2,2,2,2, 3,3,3,3]")
    expected_digits = [0]*4 + [1]*4 + [2]*4 + [3]*4
    if digits == expected_digits:
        print("  ✓ Correct! Elements are grouped by digit.")
    else:
        print("  ✗ Something went wrong!")


if __name__ == "__main__":
    main()