Parallel Radix Sort
"""
Pseudo-CUDA implementation of radix sort local block sort.
This shows how the ACTUAL GPU algorithm works at the thread level,
with masks and parallel prefix scans - no sequential counting!
Key insight: Instead of a sequential loop with `counter += 1`,
each thread discovers its position via parallel operations.
"""
from typing import List, Tuple
from dataclasses import dataclass
# =============================================================================
# Simulated CUDA Environment
# =============================================================================
@dataclass
class ThreadContext:
"""Simulates CUDA thread-local state."""
thread_id: int
block_size: int
# Thread-private variables (each thread has its own copy)
my_value: int = 0
my_digit: int = 0
my_position: int = 0 # Position within my digit group
class SharedMemory:
"""Simulates CUDA shared memory (visible to all threads in block)."""
def __init__(self, size: int):
self.data = [0] * size
def __getitem__(self, idx: int) -> int:
return self.data[idx]
def __setitem__(self, idx: int, value: int):
self.data[idx] = value
class BlockSimulator:
"""
Simulates a CUDA block with multiple threads.
In real CUDA:
- All threads execute the SAME code simultaneously
- __syncthreads() makes all threads wait for each other
- Shared memory is visible to all threads in the block
"""
def __init__(self, block_size: int):
self.block_size = block_size
self.threads = [ThreadContext(i, block_size) for i in range(block_size)]
# Shared memory arrays
self.smem_mask = SharedMemory(block_size)
self.smem_scan = SharedMemory(block_size + 1) # +1 for exclusive scan
self.smem_digit_offsets = SharedMemory(4)
self.smem_digit_counts = SharedMemory(4)
self.smem_output = SharedMemory(block_size)
self.smem_output_prefix = SharedMemory(block_size)
def barrier(self):
"""
Simulates __syncthreads().
In real CUDA, this makes all threads wait until everyone reaches this point.
In our simulation, we just note where barriers happen.
"""
pass # In simulation, we execute threads in lockstep anyway
def run_kernel(self, kernel_func, *args):
"""
Run a kernel function for all threads.
In real CUDA: all threads run simultaneously.
In simulation: we run them one at a time, but the algorithm
is designed so order doesn't matter between barriers.
"""
for thread in self.threads:
kernel_func(thread, self, *args)
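# Usage sketch (illustrative only; `double_kernel` is not part of the sort):
# any function taking (thread, block, *args) can be launched as a "kernel":
#
#     def double_kernel(thread, block):
#         thread.my_value *= 2
#
#     blk = BlockSimulator(4)
#     blk.run_kernel(double_kernel)  # every thread runs double_kernel once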
# =============================================================================
# Parallel Primitives
# =============================================================================
def parallel_exclusive_scan_simulation(block: BlockSimulator, input_smem: SharedMemory,
                                       output_smem: SharedMemory, n: int):
    """
    Simulates a parallel exclusive prefix scan.
    In real CUDA, this would be done with a tree-based algorithm
    where all threads participate simultaneously.
        Input:  [0, 0, 0, 1, 0, 0, 0, 1]
        Output: [0, 0, 0, 0, 1, 1, 1, 1]
    Each thread reads from input, writes to output.
    """
    # For simulation, just do it sequentially
    # (on a real GPU this takes O(log n) parallel steps)
    running_sum = 0
    for i in range(n):
        output_smem[i] = running_sum
        running_sum += input_smem[i]
    output_smem[n] = running_sum  # Total sum at the end
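# A minimal host-side sketch (not from the original gist) of the tree-based
# scan the docstring above alludes to, in Hillis-Steele form: log2(n) rounds,
# where in round d every element adds the value 2**d slots to its left. On a
# GPU each thread does one addition per round; here the rounds are simulated
# with list copies (the copy plays the role of the double buffer / barrier).
def hillis_steele_exclusive_scan(values: List[int]) -> List[int]:
    """Illustrative O(log n)-step scan; returns the exclusive prefix sums."""
    n = len(values)
    inclusive = list(values)
    step = 1
    while step < n:
        # All threads read the previous round's buffer, then all write;
        # this is why real implementations double-buffer or barrier twice.
        prev = list(inclusive)
        for i in range(step, n):
            inclusive[i] = prev[i] + prev[i - step]
        step *= 2
    # Shift right by one to turn the inclusive scan into an exclusive one
    return [0] + inclusive[:-1]
# e.g. hillis_steele_exclusive_scan([0, 0, 0, 1, 0, 0, 0, 1])
#      == [0, 0, 0, 0, 1, 1, 1, 1], matching the docstring example above.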
# =============================================================================
# The Actual CUDA Kernel (Pseudo-code as Python)
# =============================================================================
def extract_digit(value: int, shift: int) -> int:
    return (value >> shift) & 0b11
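# Quick example: extract_digit(0b1110, 0) == 0b10 (bits 0-1), and
#                extract_digit(0b1110, 2) == 0b11 (bits 2-3).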
def kernel_phase1_load(thread: ThreadContext, block: BlockSimulator,
                       input_data: List[int], shift_width: int):
    """
    Step 1: Each thread loads its element and extracts its digit.
    CUDA equivalent:
        __global__ void load_kernel(uint* input, int shift) {
            int tid = threadIdx.x;
            my_value = input[tid];
            my_digit = (my_value >> shift) & 0x3;
        }
    """
    tid = thread.thread_id
    thread.my_value = input_data[tid]
    thread.my_digit = extract_digit(thread.my_value, shift_width)
def kernel_phase2_build_mask_and_scan(thread: ThreadContext, block: BlockSimulator,
                                      target_digit: int):
    """
    Step 2: Build the mask for one digit. (In this simulation the scan over
    the mask is run by the host driver below; on a real GPU both halves live
    in the same kernel.) The mask-and-scan pair is the KEY parallel operation
    that replaces sequential counting!
    CUDA equivalent:
        __global__ void build_mask_and_scan(int target_digit) {
            int tid = threadIdx.x;
            // Each thread: "Do I have this digit?"
            smem_mask[tid] = (my_digit == target_digit) ? 1 : 0;
            __syncthreads();
            // Parallel prefix scan (in practice, use CUB or a custom implementation)
            smem_scan[tid] = parallel_exclusive_scan(smem_mask)[tid];
            __syncthreads();
        }
    """
    tid = thread.thread_id
    # Each thread writes to its own slot - NO CONFLICTS
    block.smem_mask[tid] = 1 if thread.my_digit == target_digit else 0
def kernel_phase3_save_position(thread: ThreadContext, block: BlockSimulator,
                                target_digit: int):
    """
    Step 3: Each thread with the target digit reads its position from the scan.
    CUDA equivalent:
        __global__ void save_position(int target_digit) {
            int tid = threadIdx.x;
            if (my_digit == target_digit) {
                my_position = smem_scan[tid];
            }
        }
    """
    tid = thread.thread_id
    if thread.my_digit == target_digit:
        # The scan result tells me: "How many elements with this digit are before me?"
        thread.my_position = block.smem_scan[tid]
def kernel_phase4_compute_digit_counts(thread: ThreadContext, block: BlockSimulator):
    """
    Step 4: Compute how many of each digit we have, and where each group starts.
    Only 4 threads are needed for this (or thread 0 does it all).
    NOTE: In this simulation the per-digit totals and offsets are gathered by
    the host driver in cuda_local_block_sort after each scan, so this kernel
    is only a placeholder marking where that work would happen on a real GPU.
    """
    tid = thread.thread_id
    if tid < 4:
        # On a real GPU, thread `tid` would publish the total for digit `tid`
        # (the last element of that digit's scan) to shared memory here.
        pass
def kernel_phase5_scatter(thread: ThreadContext, block: BlockSimulator):
    """
    Step 5: Each thread writes its element to the correct sorted position.
    CUDA equivalent:
        __global__ void scatter() {
            int tid = threadIdx.x;
            int new_pos = digit_offsets[my_digit] + my_position;
            output[new_pos] = my_value;
            output_prefix[new_pos] = my_position;
        }
    """
    digit = thread.my_digit
    digit_offset = block.smem_digit_offsets[digit]
    new_pos = digit_offset + thread.my_position
    # Each thread writes to a DIFFERENT position - NO CONFLICTS
    block.smem_output[new_pos] = thread.my_value
    block.smem_output_prefix[new_pos] = thread.my_position
# =============================================================================
# Complete Local Sort Simulation
# =============================================================================
def cuda_local_block_sort(input_data: List[int], shift_width: int) -> Tuple[List[int], List[int], List[int]]:
    """
    Simulate the full CUDA local block sort.
    Shows step-by-step what happens in a real GPU.
    """
    n = len(input_data)
    block = BlockSimulator(n)
    print(f"\n{'='*70}")
    print("CUDA LOCAL BLOCK SORT SIMULATION")
    print(f"Input: {input_data}")
    print(f"Shift: {shift_width} (looking at bits {shift_width}-{shift_width+1})")
    print(f"{'='*70}")
    # =========================================================================
    # STEP 1: Load data (all threads in parallel)
    # =========================================================================
    print(f"\n[STEP 1] All {n} threads load their element and extract digit")
    block.run_kernel(kernel_phase1_load, input_data, shift_width)
    print(f"  Thread ID: {[t.thread_id for t in block.threads]}")
    print(f"  my_value:  {[t.my_value for t in block.threads]}")
    print(f"  my_digit:  {[t.my_digit for t in block.threads]}")
    # =========================================================================
    # STEP 2-3: For each digit, build mask and scan (this is the key!)
    # =========================================================================
    digit_counts = [0, 0, 0, 0]
    for target_digit in range(4):
        print(f"\n[STEP 2-3] Processing digit {target_digit}")
        # All threads build mask in parallel
        print(f"  All threads: 'Do I have digit {target_digit}?'")
        block.run_kernel(kernel_phase2_build_mask_and_scan, target_digit)
        mask = [block.smem_mask[i] for i in range(n)]
        print(f"  Mask:        {mask}")
        # Barrier - wait for all threads
        block.barrier()
        # Parallel prefix scan (simulated)
        parallel_exclusive_scan_simulation(block, block.smem_mask, block.smem_scan, n)
        scan_result = [block.smem_scan[i] for i in range(n)]
        total = block.smem_scan[n]
        print(f"  Scan result: {scan_result} (total: {total})")
        digit_counts[target_digit] = total
        # Barrier
        block.barrier()
        # Each thread with this digit saves its position
        block.run_kernel(kernel_phase3_save_position, target_digit)
        positions = [t.my_position if t.my_digit == target_digit else '-'
                     for t in block.threads]
        print(f"  Positions:   {positions}")
    print("\n[STEP 4] Compute digit offsets")
    print(f"  Digit counts: {digit_counts}")
    # Exclusive scan of digit counts to get offsets
    digit_offsets = [0, 0, 0, 0]
    running = 0
    for d in range(4):
        digit_offsets[d] = running
        block.smem_digit_offsets[d] = running
        running += digit_counts[d]
    print(f"  Digit offsets: {digit_offsets}")
    # =========================================================================
    # STEP 5: Scatter (all threads in parallel)
    # =========================================================================
    print("\n[STEP 5] All threads scatter to sorted positions")
    block.run_kernel(kernel_phase5_scatter)
    for t in block.threads:
        new_pos = digit_offsets[t.my_digit] + t.my_position
        print(f"  Thread {t.thread_id}: value={t.my_value}, digit={t.my_digit}, "
              f"offset={digit_offsets[t.my_digit]}, position={t.my_position} "
              f"→ output[{new_pos}]")
    # =========================================================================
    # Extract results
    # =========================================================================
    sorted_output = [block.smem_output[i] for i in range(n)]
    sorted_prefix = [block.smem_output_prefix[i] for i in range(n)]
    print("\n[RESULT]")
    print(f"  Sorted output:  {sorted_output}")
    print(f"  Local prefixes: {sorted_prefix}")
    print(f"  Digit counts:   {digit_counts}")
    return sorted_output, sorted_prefix, digit_counts
# =============================================================================
# Why This Is Parallel (No Sequential Counting!)
# =============================================================================
def explain_parallelism():
"""
Explain why the mask-and-scan approach is truly parallel.
"""
print("""
╔══════════════════════════════════════════════════════════════════════════════╗
║ WHY MASK-AND-SCAN IS PARALLEL ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ ║
║ SEQUENTIAL APPROACH (what we had before): ║
║ ───────────────────────────────────────── ║
║ ║
║ counter = 0 ║
║ for i in range(n): # Must be sequential! ║
║ position[i] = counter # Read counter ║
║ counter += 1 # Modify counter ║
║ # Next iteration depends on this! ║
║ ║
║ Problem: Iteration i+1 needs the result of iteration i. ║
║ ║
║ ──────────────────────────────────────────────────────────────────────── ║
║ ║
║ PARALLEL APPROACH (mask and scan): ║
║ ────────────────────────────────── ║
║ ║
║ Step 1: Build mask (ALL THREADS SIMULTANEOUSLY) ║
║ ┌─────────────────────────────────────────────────────────┐ ║
║ │ Thread 0: mask[0] = (my_digit == target) ? 1 : 0 │ ║
║ │ Thread 1: mask[1] = (my_digit == target) ? 1 : 0 │ PARALLEL ║
║ │ Thread 2: mask[2] = (my_digit == target) ? 1 : 0 │ (no deps) ║
║ │ ... │ ║
║ └─────────────────────────────────────────────────────────┘ ║
║ ║
║ Step 2: Parallel prefix scan (O(log n) steps, all threads active) ║
║ ┌─────────────────────────────────────────────────────────┐ ║
║ │ Input: [0, 0, 0, 1, 0, 0, 0, 1] │ ║
║ │ ↓ (log₂(8) = 3 parallel steps) │ ║
║ │ Output: [0, 0, 0, 0, 1, 1, 1, 1] │ ║
║ └─────────────────────────────────────────────────────────┘ ║
║ ║
║ Step 3: Read position (ALL THREADS SIMULTANEOUSLY) ║
║ ┌─────────────────────────────────────────────────────────┐ ║
║ │ Thread 3: my_position = scan[3] = 0 "I'm 1st!" │ ║
║ │ Thread 7: my_position = scan[7] = 1 "I'm 2nd!" │ PARALLEL ║
║ │ (Other threads don't have this digit, they skip) │ (no deps) ║
║ └─────────────────────────────────────────────────────────┘ ║
║ ║
║ No thread depends on another thread's result! ║
║ Each thread discovers its position INDEPENDENTLY. ║
║ ║
╚══════════════════════════════════════════════════════════════════════════════╝
""")
# =============================================================================
# Test and Demo
# =============================================================================
def test_cuda_sort():
    """Test that the CUDA-style sort matches the expected output."""
    input_data = [7, 2, 5, 0, 3, 6, 1, 4]
    sorted_out, prefix_out, counts = cuda_local_block_sort(input_data, shift_width=0)
    expected_sorted = [0, 4, 5, 1, 2, 6, 7, 3]
    expected_counts = [2, 2, 2, 2]
    assert sorted_out == expected_sorted, f"Expected {expected_sorted}, got {sorted_out}"
    assert counts == expected_counts, f"Expected {expected_counts}, got {counts}"
    print("\n✓ Test passed!")
if __name__ == "__main__":
explain_parallelism()
test_cuda_sort()
"""
Full Pseudo-CUDA implementation of radix sort with MULTIPLE BLOCKS.
Shows all three phases:
1. Local block sort (each block sorts its portion)
2. Scan block sums (compute global offsets)
3. Global shuffle (scatter to final positions)
Using 16 elements split into 2 blocks of 8.
"""
from typing import List, Tuple, Dict
from dataclasses import dataclass, field
# =============================================================================
# Simulated GPU Memory
# =============================================================================
class GlobalMemory:
"""
Simulates GPU global memory - accessible by ALL blocks.
This is like cudaMalloc'd memory.
"""
def __init__(self):
self.buffers: Dict[str, List[int]] = {}
def allocate(self, name: str, size: int, init_value: int = 0):
self.buffers[name] = [init_value] * size
def read(self, name: str, idx: int) -> int:
return self.buffers[name][idx]
def write(self, name: str, idx: int, value: int):
self.buffers[name][idx] = value
def get_buffer(self, name: str) -> List[int]:
return self.buffers[name]
class SharedMemory:
"""Simulates CUDA shared memory (only visible within one block)."""
def __init__(self, size: int):
self.data = [0] * size
def __getitem__(self, idx: int) -> int:
return self.data[idx]
def __setitem__(self, idx: int, value: int):
self.data[idx] = value
@dataclass
class ThreadContext:
"""Per-thread state."""
thread_id: int # threadIdx.x (local to block)
block_id: int # blockIdx.x
global_id: int # blockIdx.x * blockDim.x + threadIdx.x
# Thread-private variables
my_value: int = 0
my_digit: int = 0
my_local_position: int = 0 # Position within digit group IN THIS BLOCK
my_global_position: int = 0 # Final position in output array
# =============================================================================
# Helper Functions
# =============================================================================
def extract_digit(value: int, shift: int) -> int:
    return (value >> shift) & 0b11
def exclusive_scan(arr: List[int]) -> List[int]:
    """Exclusive prefix sum."""
    result = [0] * len(arr)
    running = 0
    for i in range(len(arr)):
        result[i] = running
        running += arr[i]
    return result
# =============================================================================
# PHASE 1: Local Block Sort (runs on each block independently)
# =============================================================================
def phase1_local_sort(
    block_id: int,
    block_size: int,
    global_mem: GlobalMemory,
    shift_width: int
) -> Tuple[List[int], List[int], List[int]]:
    """
    Each block sorts its portion of the input.
    Returns:
        sorted_data: Locally sorted elements
        local_prefix: Each element's position within its digit group
        digit_counts: [count_0, count_1, count_2, count_3] for this block
    """
    print(f"\n  [Block {block_id}] Local sort")
    # Create threads for this block
    threads = [
        ThreadContext(
            thread_id=t,
            block_id=block_id,
            global_id=block_id * block_size + t
        )
        for t in range(block_size)
    ]
    # Shared memory for this block
    smem_mask = [0] * block_size
    smem_scan = [0] * (block_size + 1)
    # Step 1: Each thread loads its element
    input_buffer = global_mem.get_buffer("input")
    for t in threads:
        t.my_value = input_buffer[t.global_id]
        t.my_digit = extract_digit(t.my_value, shift_width)
    print(f"    Thread IDs: {[t.thread_id for t in threads]}")
    print(f"    Values:     {[t.my_value for t in threads]}")
    print(f"    Digits:     {[t.my_digit for t in threads]}")
    # Step 2-3: Mask and scan for each digit
    digit_counts = [0, 0, 0, 0]
    for target_digit in range(4):
        # Build mask (parallel)
        for t in threads:
            smem_mask[t.thread_id] = 1 if t.my_digit == target_digit else 0
        # Exclusive scan (parallel in real CUDA)
        running = 0
        for i in range(block_size):
            smem_scan[i] = running
            running += smem_mask[i]
        smem_scan[block_size] = running  # Total
        digit_counts[target_digit] = running
        # Each thread with this digit saves its position
        for t in threads:
            if t.my_digit == target_digit:
                t.my_local_position = smem_scan[t.thread_id]
    print(f"    Digit counts: {digit_counts}")
    # Step 4: Compute digit offsets within block
    digit_offsets = exclusive_scan(digit_counts)
    # Step 5: Scatter to sorted positions (within block)
    sorted_data = [0] * block_size
    local_prefix = [0] * block_size
    for t in threads:
        new_pos = digit_offsets[t.my_digit] + t.my_local_position
        sorted_data[new_pos] = t.my_value
        local_prefix[new_pos] = t.my_local_position
    print(f"    Sorted:       {sorted_data}")
    print(f"    Local prefix: {local_prefix}")
    return sorted_data, local_prefix, digit_counts
# =============================================================================
# PHASE 2: Scan Block Sums (compute global offsets)
# =============================================================================
def phase2_scan_block_sums(
    all_block_counts: List[List[int]],  # [block][digit]
    num_blocks: int
) -> List[List[int]]:
    """
    Compute the global starting position for each (digit, block) pair.
    In CUDA, this could be done by:
    - A separate kernel launch
    - Or integrated with Phase 3
    The key insight: we need to know where EACH block's elements
    for EACH digit should go in the final output.
    Returns:
        global_offsets[block][digit] = starting position in output
    """
    print(f"\n{'='*70}")
    print("PHASE 2: Scan Block Sums")
    print(f"{'='*70}")
    # Flatten: all digit-0 counts, then all digit-1 counts, etc.
    # This is how the CUDA algorithm typically lays it out:
    #   [block0_d0, block1_d0, block0_d1, block1_d1, ...]
    print("\n  Block counts per digit:")
    for d in range(4):
        counts = [all_block_counts[b][d] for b in range(num_blocks)]
        print(f"    Digit {d}: {counts}")
    # Compute global offsets
    # Order: all digit 0s first, then digit 1s, then 2s, then 3s
    global_offsets = [[0] * 4 for _ in range(num_blocks)]
    running = 0
    for digit in range(4):
        for block in range(num_blocks):
            global_offsets[block][digit] = running
            running += all_block_counts[block][digit]
    print("\n  Global offsets (where each block's digit group starts):")
    for d in range(4):
        offsets = [global_offsets[b][d] for b in range(num_blocks)]
        print(f"    Digit {d}: blocks start at {offsets}")
    return global_offsets
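# Quick worked example (matches the diagram further below): with two blocks
# whose per-digit counts are both [2, 2, 2, 2], the flattened scan walks
# block0_d0, block1_d0, block0_d1, block1_d1, ... and yields
#     phase2_scan_block_sums([[2, 2, 2, 2], [2, 2, 2, 2]], 2)
#     == [[0, 4, 8, 12], [2, 6, 10, 14]]
# i.e. block 0's digit-0 group starts at 0, block 1's at 2, and so on.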
# =============================================================================
# PHASE 3: Global Shuffle (scatter to final positions)
# =============================================================================
def phase3_global_shuffle(
    block_id: int,
    block_size: int,
    sorted_data: List[int],
    local_prefix: List[int],
    global_offsets: List[List[int]],
    global_mem: GlobalMemory,
    shift_width: int
):
    """
    Each block scatters its elements to final global positions.
        global_position = global_offsets[block][digit] + local_prefix
    In CUDA, ALL blocks run this in parallel.
    """
    print(f"\n  [Block {block_id}] Global shuffle")
    for i in range(block_size):
        value = sorted_data[i]
        digit = extract_digit(value, shift_width)
        local_pos = local_prefix[i]
        global_pos = global_offsets[block_id][digit] + local_pos
        global_mem.write("output", global_pos, value)
        print(f"    Element {i}: value={value}, digit={digit}, "
              f"global_offset={global_offsets[block_id][digit]}, "
              f"local_pos={local_pos} → output[{global_pos}]")
# =============================================================================
# Complete Multi-Block Radix Sort Pass
# =============================================================================
def radix_sort_pass_multiblock(
    input_data: List[int],
    block_size: int,
    shift_width: int
) -> List[int]:
    """
    One complete pass of radix sort across multiple blocks.
    """
    n = len(input_data)
    num_blocks = (n + block_size - 1) // block_size
    # Pad input to a multiple of block_size; the 0xFFFFFFFF sentinel sorts last
    padded_input = input_data + [0xFFFFFFFF] * (num_blocks * block_size - n)
    # Initialize global memory
    global_mem = GlobalMemory()
    global_mem.allocate("input", len(padded_input))
    global_mem.allocate("output", len(padded_input))
    for i, v in enumerate(padded_input):
        global_mem.write("input", i, v)
    print(f"\n{'#'*70}")
    print(f"# RADIX SORT PASS (shift={shift_width}, looking at bits {shift_width}-{shift_width+1})")
    print(f"# Input: {input_data}")
    print(f"# {num_blocks} blocks of {block_size} elements each")
    print(f"{'#'*70}")
    # =========================================================================
    # PHASE 1: Local Block Sort (all blocks run in parallel on GPU)
    # =========================================================================
    print(f"\n{'='*70}")
    print("PHASE 1: Local Block Sort (each block independently)")
    print(f"{'='*70}")
    all_sorted_blocks = []
    all_local_prefixes = []
    all_block_counts = []
    for block_id in range(num_blocks):
        sorted_data, local_prefix, digit_counts = phase1_local_sort(
            block_id, block_size, global_mem, shift_width
        )
        all_sorted_blocks.append(sorted_data)
        all_local_prefixes.append(local_prefix)
        all_block_counts.append(digit_counts)
    # =========================================================================
    # PHASE 2: Scan Block Sums
    # =========================================================================
    global_offsets = phase2_scan_block_sums(all_block_counts, num_blocks)
    # =========================================================================
    # PHASE 3: Global Shuffle (all blocks run in parallel on GPU)
    # =========================================================================
    print(f"\n{'='*70}")
    print("PHASE 3: Global Shuffle (each block independently)")
    print(f"{'='*70}")
    for block_id in range(num_blocks):
        phase3_global_shuffle(
            block_id, block_size,
            all_sorted_blocks[block_id],
            all_local_prefixes[block_id],
            global_offsets,
            global_mem,
            shift_width
        )
    # =========================================================================
    # Result
    # =========================================================================
    output = global_mem.get_buffer("output")[:n]
    print(f"\n{'='*70}")
    print("RESULT")
    print(f"{'='*70}")
    print(f"  Output: {output}")
    return output
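# A minimal sketch (an addition, not part of the original gist): because each
# pass is stable, chaining passes from the least-significant digit upward
# gives a full LSD radix sort. For 32-bit keys and 2-bit digits that is the
# 16 passes mentioned in the diagram below; for the small demo values here,
# the high bits are all zero, so later passes leave the order unchanged.
def full_radix_sort_sketch(data: List[int], block_size: int = 8,
                           key_bits: int = 32) -> List[int]:
    """Chain one radix_sort_pass_multiblock per 2-bit digit (illustrative)."""
    for shift in range(0, key_bits, 2):
        data = radix_sort_pass_multiblock(data, block_size, shift)
    return data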
# =============================================================================
# Visualization of the Full Algorithm
# =============================================================================
def visualize_algorithm():
"""ASCII art explanation of the multi-block algorithm."""
print("""
╔══════════════════════════════════════════════════════════════════════════════╗
║ MULTI-BLOCK RADIX SORT OVERVIEW ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ ║
║ INPUT: 16 elements, 2 blocks of 8 ║
║ ┌─────────────────────────────┬─────────────────────────────┐ ║
║ │ Block 0 (elements 0-7) │ Block 1 (elements 8-15)│ ║
║ │ [7, 2, 5, 0, 3, 6, 1, 4] │ [10, 9, 8, 11, 14, 13, 12, 15] │ ║
║ └─────────────────────────────┴─────────────────────────────┘ ║
║ │ │ ║
║ ▼ ▼ ║
║ ════════════════════════════════════════════════════════════════════════ ║
║ PHASE 1: Local Sort (blocks work independently, IN PARALLEL) ║
║ ════════════════════════════════════════════════════════════════════════ ║
║ │ │ ║
║ ▼ ▼ ║
║ ┌─────────────────────────────┬─────────────────────────────┐ ║
║ │ Sorted: [0,4, 5,1, 2,6, 7,3]│ Sorted: [8,12, 9,13, 10,14, 11,15]│ ║
║ │ Counts: [2,2,2,2] │ Counts: [2,2,2,2] │ ║
║ │ d0 d1 d2 d3 │ d0 d1 d2 d3 │ ║
║ └─────────────────────────────┴─────────────────────────────┘ ║
║ │ │ ║
║ └──────────────┬───────────────┘ ║
║ ▼ ║
║ ════════════════════════════════════════════════════════════════════════ ║
║ PHASE 2: Scan Block Sums (compute global offsets) ║
║ ════════════════════════════════════════════════════════════════════════ ║
║ ║
║ Digit 0: Block0 has 2, Block1 has 2 ║
║ Digit 1: Block0 has 2, Block1 has 2 ║
║ Digit 2: Block0 has 2, Block1 has 2 ║
║ Digit 3: Block0 has 2, Block1 has 2 ║
║ ║
║ Global offsets (where does each block's digit group go?): ║
║ ┌────────┬─────────┬─────────┐ ║
║ │ Digit │ Block 0 │ Block 1 │ ║
║ ├────────┼─────────┼─────────┤ ║
║ │ 0 │ 0 │ 2 │ ← Digit 0: positions 0-3 ║
║ │ 1 │ 4 │ 6 │ ← Digit 1: positions 4-7 ║
║ │ 2 │ 8 │ 10 │ ← Digit 2: positions 8-11 ║
║ │ 3 │ 12 │ 14 │ ← Digit 3: positions 12-15 ║
║ └────────┴─────────┴─────────┘ ║
║ ║
║ ▼ ║
║ ════════════════════════════════════════════════════════════════════════ ║
║ PHASE 3: Global Shuffle (blocks work independently, IN PARALLEL) ║
║ ════════════════════════════════════════════════════════════════════════ ║
║ ║
║ Block 0: value=0 (digit 0) → offset[0][0]=0 + local_pos=0 → output[0] ║
║ Block 0: value=4 (digit 0) → offset[0][0]=0 + local_pos=1 → output[1] ║
║ Block 1: value=8 (digit 0) → offset[1][0]=2 + local_pos=0 → output[2] ║
║ Block 1: value=12(digit 0) → offset[1][0]=2 + local_pos=1 → output[3] ║
║ ... and so on for digits 1, 2, 3 ... ║
║ ║
║ ▼ ║
║ ┌───────────────────────────────────────────────────────────────────────┐ ║
║ │ OUTPUT (sorted by 2-bit digit): │ ║
║ │ [0, 4, 8, 12, 5, 1, 9, 13, 2, 6, 10, 14, 7, 3, 11, 15] │ ║
║ │ └─ digit 0 ─┘ └─ digit 1 ─┘ └─ digit 2 ─┘ └─ digit 3 ─┘ │ ║
║ └───────────────────────────────────────────────────────────────────────┘ ║
║ ║
║ After 16 passes (2 bits × 16 = 32 bits), array is fully sorted! ║
║ ║
╚══════════════════════════════════════════════════════════════════════════════╝
""")
# =============================================================================
# Main Demo
# =============================================================================
def main():
    visualize_algorithm()
    # 16 elements, 2 blocks of 8
    input_data = [7, 2, 5, 0, 3, 6, 1, 4, 10, 9, 8, 11, 14, 13, 12, 15]
    print("\n" + "="*70)
    print("RUNNING SIMULATION WITH ACTUAL DATA")
    print("="*70)
    result = radix_sort_pass_multiblock(
        input_data=input_data,
        block_size=8,
        shift_width=0  # Look at bits 0-1
    )
    # Verify: after one pass, elements should be grouped by their 2-bit digit
    print(f"\n{'='*70}")
    print("VERIFICATION")
    print(f"{'='*70}")
    digits = [extract_digit(v, 0) for v in result]
    print(f"  Result digits: {digits}")
    print("  Should be:     [0,0,0,0, 1,1,1,1, 2,2,2,2, 3,3,3,3]")
    expected_digits = [0]*4 + [1]*4 + [2]*4 + [3]*4
    if digits == expected_digits:
        print("  ✓ Correct! Elements are grouped by digit.")
    else:
        print("  ✗ Something went wrong!")

if __name__ == "__main__":
    main()