Created
December 30, 2025 01:45
-
-
Save NTT123/a3854dcaed9cf1b1a787c9d8f5bd8433 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| // All-Gather using Cooperative Groups grid.sync() with vectorized memory access | |
| // RTX 5090: 170 SMs, 1 block per SM, 16 bytes (uint4) per SM to share | |
| // Persistent kernel: multiple rounds of all-gather, each with different buffer | |
| #include <cuda_runtime.h> | |
| #include <cooperative_groups.h> | |
| #include <stdio.h> | |
| #include <climits> | |
| namespace cg = cooperative_groups; | |
| #define NUM_SMS 170 | |
| #define BYTES_PER_SM 16 // Payload each block publishes per iteration: one uint4 (128 bits) | |
| #define NUM_ITERS 10 | |
| // Exchange area in global memory: g_buffer[iter][block] holds the uint4 that block publishes in that iteration | |
| __device__ uint4 g_buffer[NUM_ITERS][NUM_SMS]; | |
| // Write the four 32-bit lanes of `val` to `ptr` with one PTX | |
| // st.volatile.global.v4.u32 instruction (`ptr` is addressed in the .global space). | |
| __device__ __forceinline__ void store_uint4_volatile(uint4* ptr, uint4 val) { | |
|     const unsigned int lane0 = val.x; | |
|     const unsigned int lane1 = val.y; | |
|     const unsigned int lane2 = val.z; | |
|     const unsigned int lane3 = val.w; | |
|     // The "memory" clobber marks this statement as a memory write for the compiler, | |
|     // so other memory accesses are not cached in registers or moved across it. | |
|     asm volatile("st.volatile.global.v4.u32 [%0], {%1, %2, %3, %4};" | |
|                  : | |
|                  : "l"(ptr), "r"(lane0), "r"(lane1), "r"(lane2), "r"(lane3) | |
|                  : "memory"); | |
| } | |
| // Read one uint4 from `ptr` with a single PTX ld.volatile.global.v4.u32 | |
| // instruction (`ptr` is addressed in the .global space) and return it. | |
| // Intended for polling a location that another thread block writes. | |
| __device__ __forceinline__ uint4 load_uint4_volatile(const uint4* ptr) { | |
|     uint4 ret; | |
|     // The "memory" clobber (matching store_uint4_volatile) makes this statement a | |
|     // compiler-level barrier: ordinary loads/stores around a polling loop are not | |
|     // reordered across the load or satisfied from values cached in registers. | |
|     asm volatile( | |
|         "ld.volatile.global.v4.u32 {%0, %1, %2, %3}, [%4];" | |
|         : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) | |
|         : "l"(ptr) | |
|         : "memory" | |
|     ); | |
|     return ret; | |
| } | |
| // Read one uint4 from `ptr` with a single PTX ld.global.cg.v4.u32 instruction | |
| // (the .cg cache operator) and return the four lanes packed as a uint4. | |
| __device__ __forceinline__ uint4 load_uint4_l2(const uint4* ptr) { | |
|     unsigned int lane0, lane1, lane2, lane3; | |
|     asm volatile("ld.global.cg.v4.u32 {%0, %1, %2, %3}, [%4];" | |
|                  : "=r"(lane0), "=r"(lane1), "=r"(lane2), "=r"(lane3) | |
|                  : "l"(ptr)); | |
|     return make_uint4(lane0, lane1, lane2, lane3); | |
| } | |
| // Kernel outputs read back by the host: g_cycles[iter][block] is the clock64() delta that block measured for the iteration; g_passed[block] is 1 if the block counted no verification errors, else 0 | |
| __device__ unsigned long long g_cycles[NUM_ITERS][NUM_SMS]; | |
| __device__ int g_passed[NUM_SMS]; | |
| // Persistent all-gather benchmark. | |
| // | |
| // Launch requirements: | |
| //   * cooperative launch (the kernel calls grid.sync()), | |
| //   * 1-D grid of exactly NUM_SMS blocks (block b owns g_buffer[*][b]), | |
| //   * any 1-D block size; threads stride over the NUM_SMS slots. | |
| // Static shared memory: NUM_SMS uint4 values plus one int. | |
| // | |
| // Per iteration each block's thread 0 publishes one uint4 to g_buffer[iter][bid]; | |
| // every block then polls all NUM_SMS slots of that iteration until none holds the | |
| // sentinel, copies them to shared memory, records the elapsed clock64() cycles in | |
| // g_cycles[iter][bid] and checks the gathered values. g_passed[bid] is written once | |
| // at the end (1 = no errors in any iteration). | |
| __global__ void all_gather_kernel() { | |
|     // Marker stored in a slot before its owner publishes real data. Payload values | |
|     // are iter * 1000 + bid * 100 + {0..3}, which stay below this for the | |
|     // NUM_ITERS / NUM_SMS values used in this file. | |
|     constexpr unsigned int kSentinel = 999999u; | |
|     cg::grid_group grid = cg::this_grid(); | |
|     const int bid = blockIdx.x; | |
|     const int tid = threadIdx.x; | |
|     // Gathered data for all blocks (NUM_SMS * 16 bytes). | |
|     __shared__ uint4 smem[NUM_SMS]; | |
|     // Per-iteration error count for this block. | |
|     __shared__ int iter_errors; | |
|     // The indexing below is only valid for exactly NUM_SMS blocks. gridDim is the | |
|     // same for every thread in the grid, so either all threads leave here or none | |
|     // do, and nobody is left waiting in grid.sync(). | |
|     if (gridDim.x != NUM_SMS) { | |
|         if (tid == 0 && bid < NUM_SMS) { | |
|             g_passed[bid] = 0; | |
|         } | |
|         return; | |
|     } | |
|     int total_errors = 0; | |
|     // Mark this block's slot of every iteration as "not yet published". | |
|     if (tid == 0) { | |
|         const uint4 sentinel = make_uint4(kSentinel, kSentinel, kSentinel, kSentinel); | |
|         for (int i = 0; i < NUM_ITERS; i++) { | |
|             store_uint4_volatile(&g_buffer[i][bid], sentinel); | |
|         } | |
|     } | |
|     // All sentinels must be in place before any block starts publishing or polling. | |
|     grid.sync(); | |
|     for (int iter = 0; iter < NUM_ITERS; iter++) { | |
|         unsigned long long t_start = clock64(); | |
|         // Step 1: thread 0 publishes this block's uint4 with one volatile 128-bit store. | |
|         // Data encoding: x = iter*1000 + bid*100, y = x+1, z = x+2, w = x+3. | |
|         if (tid == 0) { | |
|             const unsigned int base = iter * 1000 + bid * 100; | |
|             store_uint4_volatile(&g_buffer[iter][bid], | |
|                                  make_uint4(base + 0, base + 1, base + 2, base + 3)); | |
|         } | |
|         // Step 2: gather. No grid barrier is used here; each slot is polled with | |
|         // volatile loads until its owner has replaced the sentinel. The block-stride | |
|         // loop covers every slot even when blockDim.x < NUM_SMS. All four lanes are | |
|         // checked so a partially visible vector store is never accepted. | |
|         for (int sm = tid; sm < NUM_SMS; sm += blockDim.x) { | |
|             uint4 data; | |
|             do { | |
|                 data = load_uint4_volatile(&g_buffer[iter][sm]); | |
|             } while (data.x == kSentinel || data.y == kSentinel || | |
|                      data.z == kSentinel || data.w == kSentinel); | |
|             smem[sm] = data; | |
|         } | |
|         // smem is complete for the whole block after this barrier. | |
|         __syncthreads(); | |
|         unsigned long long t_end = clock64(); | |
|         // Record cycles for this iteration (publish + gather, as seen by thread 0). | |
|         if (tid == 0) { | |
|             g_cycles[iter][bid] = t_end - t_start; | |
|             iter_errors = 0; | |
|         } | |
|         __syncthreads(); | |
|         // Step 3: verify every gathered slot against the expected encoding. | |
|         int errors = 0; | |
|         for (int sm = tid; sm < NUM_SMS; sm += blockDim.x) { | |
|             const uint4 val = smem[sm]; | |
|             const unsigned int expected = (unsigned)(iter * 1000 + sm * 100); | |
|             if (val.x != expected + 0) errors++; | |
|             if (val.y != expected + 1) errors++; | |
|             if (val.z != expected + 2) errors++; | |
|             if (val.w != expected + 3) errors++; | |
|         } | |
|         // Reduce the per-thread counts into the block-wide counter. | |
|         if (errors > 0) { | |
|             atomicAdd(&iter_errors, errors); | |
|         } | |
|         __syncthreads(); | |
|         total_errors += iter_errors; | |
|         // Grid-wide barrier before the next iteration; it also keeps thread 0 from | |
|         // resetting iter_errors while other threads are still reading it. | |
|         grid.sync(); | |
|     } | |
|     // Thread 0 records final pass/fail for this block. | |
|     if (tid == 0) { | |
|         g_passed[bid] = (total_errors == 0) ? 1 : 0; | |
|     } | |
| } | |
| // Print a diagnostic and return false when a CUDA runtime call did not succeed. | |
| static bool cuda_ok(cudaError_t err, const char* what) { | |
|     if (err == cudaSuccess) return true; | |
|     printf("ERROR: %s failed: %s\n", what, cudaGetErrorString(err)); | |
|     return false; | |
| } | |
| // Host driver: launches all_gather_kernel cooperatively with NUM_SMS blocks, | |
| // copies g_cycles / g_passed back and prints per-iteration timing plus the | |
| // verification verdict. Returns 0 only when every block reported success; | |
| // returns 1 on any CUDA error or verification failure. | |
| int main() { | |
|     // Query device properties for cooperative launch | |
|     int device = 0; | |
|     if (!cuda_ok(cudaGetDevice(&device), "cudaGetDevice")) return 1; | |
|     cudaDeviceProp prop; | |
|     if (!cuda_ok(cudaGetDeviceProperties(&prop, device), "cudaGetDeviceProperties")) return 1; | |
|     printf("Device: %s\n", prop.name); | |
|     printf("SMs: %d\n", prop.multiProcessorCount); | |
|     if (!prop.cooperativeLaunch) { | |
|         printf("ERROR: Cooperative launch not supported!\n"); | |
|         return 1; | |
|     } | |
|     // Use 256 threads per block | |
|     const int threads_per_block = 256; | |
|     // The kernel has no parameters, so no entry of this array is ever read; one | |
|     // placeholder element keeps the declaration valid ISO C++ (a zero-length | |
|     // `void* a[] = {}` is a compiler extension). | |
|     void* kernel_args[] = {nullptr}; | |
|     // The launch itself can be rejected (for example when NUM_SMS blocks cannot all | |
|     // be resident at once); that error is only reported through this return value. | |
|     if (!cuda_ok(cudaLaunchCooperativeKernel((void*)all_gather_kernel, | |
|                                              dim3(NUM_SMS), | |
|                                              dim3(threads_per_block), | |
|                                              kernel_args), | |
|                  "cudaLaunchCooperativeKernel")) { | |
|         return 1; | |
|     } | |
|     cudaError_t err = cudaDeviceSynchronize(); | |
|     if (err != cudaSuccess) { | |
|         printf("Kernel error: %s\n", cudaGetErrorString(err)); | |
|         return 1; | |
|     } | |
|     // Copy results back | |
|     unsigned long long h_cycles[NUM_ITERS][NUM_SMS]; | |
|     int h_passed[NUM_SMS]; | |
|     if (!cuda_ok(cudaMemcpyFromSymbol(h_cycles, g_cycles, sizeof(h_cycles)), | |
|                  "cudaMemcpyFromSymbol(g_cycles)")) return 1; | |
|     if (!cuda_ok(cudaMemcpyFromSymbol(h_passed, g_passed, sizeof(h_passed)), | |
|                  "cudaMemcpyFromSymbol(g_passed)")) return 1; | |
|     // Check all passed | |
|     int all_passed = 1; | |
|     for (int i = 0; i < NUM_SMS; i++) { | |
|         if (!h_passed[i]) all_passed = 0; | |
|     } | |
|     // Analyze results per iteration | |
|     printf("\n=== All-Gather Results (Persistent Kernel) ===\n"); | |
|     printf("Blocks: %d, Threads/block: %d, Iterations: %d\n", | |
|            NUM_SMS, threads_per_block, NUM_ITERS); | |
|     printf("Data per SM: %d bytes\n", BYTES_PER_SM); | |
|     printf("Total data gathered per iteration: %d bytes\n", NUM_SMS * BYTES_PER_SM); | |
|     printf("\n"); | |
|     // NOTE(review): the cycle-to-time conversion uses a fixed clock rate, not one | |
|     // queried from the device, so the microsecond columns are only meaningful on | |
|     // hardware running at this frequency. | |
|     double clock_ghz = 2.82; // RTX 5090 boost clock | |
|     // An iteration is as slow as its slowest block: take the max over blocks. | |
|     unsigned long long iter_max[NUM_ITERS]; | |
|     unsigned long long sum = 0; | |
|     for (int iter = 0; iter < NUM_ITERS; iter++) { | |
|         unsigned long long max_cycles = 0; | |
|         for (int sm = 0; sm < NUM_SMS; sm++) { | |
|             if (h_cycles[iter][sm] > max_cycles) { | |
|                 max_cycles = h_cycles[iter][sm]; | |
|             } | |
|         } | |
|         iter_max[iter] = max_cycles; | |
|         sum += max_cycles; | |
|     } | |
|     printf("Per-iteration max cycles:\n"); | |
|     printf("Iter | Max Cycles | Time (us)\n"); | |
|     printf("-----|------------|----------\n"); | |
|     for (int i = 0; i < NUM_ITERS; i++) { | |
|         printf("%4d | %10llu | %9.3f\n", i, iter_max[i], iter_max[i] / (clock_ghz * 1000)); | |
|     } | |
|     // Iteration 0 is reported separately from the summary statistics. | |
|     printf("\nSummary (excluding iter 0):\n"); | |
|     unsigned long long sum_ex0 = sum - iter_max[0]; | |
|     unsigned long long min_ex0 = ULLONG_MAX, max_ex0 = 0; | |
|     for (int i = 1; i < NUM_ITERS; i++) { | |
|         if (iter_max[i] < min_ex0) min_ex0 = iter_max[i]; | |
|         if (iter_max[i] > max_ex0) max_ex0 = iter_max[i]; | |
|     } | |
|     double avg_ex0 = (double)sum_ex0 / (NUM_ITERS - 1); | |
|     printf(" Min: %llu cycles (%.3f us)\n", min_ex0, min_ex0 / (clock_ghz * 1000)); | |
|     printf(" Max: %llu cycles (%.3f us)\n", max_ex0, max_ex0 / (clock_ghz * 1000)); | |
|     printf(" Avg: %.1f cycles (%.3f us)\n", avg_ex0, avg_ex0 / (clock_ghz * 1000)); | |
|     printf("\nVerification: %s\n", all_passed ? "PASSED" : "FAILED"); | |
|     return all_passed ? 0 : 1; | |
| } | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment