@NTT123 · Created December 30, 2025 01:45
gist.github.com/NTT123/a3854dcaed9cf1b1a787c9d8f5bd8433
// All-Gather using Cooperative Groups grid.sync() with vectorized memory access
// RTX 5090: 170 SMs, 1 block per SM, 16 bytes (uint4) per SM to share
// Persistent kernel: multiple rounds of all-gather, each with different buffer
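// Per-iteration flow (summary of the kernel below):
//   1. Thread 0 of block b publishes 16 bytes to g_buffer[iter][b] with a volatile store.
//   2. Threads 0..NUM_SMS-1 of every block spin-poll g_buffer[iter][t] until the
//      sentinel value disappears, then stage the result in shared memory.
//   3. __syncthreads(), record elapsed clock64() cycles, and verify the gathered data.
//   4. grid.sync() before moving on to the next iteration's buffer.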
#include <cuda_runtime.h>
#include <cooperative_groups.h>
#include <stdio.h>
#include <climits>
namespace cg = cooperative_groups;
#define NUM_SMS 170
#define BYTES_PER_SM 16 // One uint4 (128-bit) per SM
#define NUM_ITERS 10
// Global buffer for inter-SM communication - one uint4 per SM per iteration
__device__ uint4 g_buffer[NUM_ITERS][NUM_SMS];
// Store uint4 with volatile semantics using PTX (bypass cache)
__device__ __forceinline__ void store_uint4_volatile(uint4* ptr, uint4 val) {
    asm volatile(
        "st.volatile.global.v4.u32 [%0], {%1, %2, %3, %4};"
        :
        : "l"(ptr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w)
        : "memory"
    );
}
// Load uint4 with volatile semantics using PTX (bypass cache)
__device__ __forceinline__ uint4 load_uint4_volatile(const uint4* ptr) {
    uint4 ret;
    asm volatile(
        "ld.volatile.global.v4.u32 {%0, %1, %2, %3}, [%4];"
        : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w)
        : "l"(ptr)
    );
    return ret;
}
// Load uint4 with L2 cache hint using PTX
__device__ __forceinline__ uint4 load_uint4_l2(const uint4* ptr) {
    uint4 ret;
    asm volatile(
        "ld.global.cg.v4.u32 {%0, %1, %2, %3}, [%4];"
        : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w)
        : "l"(ptr)
    );
    return ret;
}
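// Alternative sketch (not used below; assumes a CUDA toolkit and architecture that
// expose the cache-hint intrinsics for uint4): the L2-cached load above could likely
// be written without inline PTX via __ldcg, which should lower to ld.global.cg.
__device__ __forceinline__ uint4 load_uint4_l2_intrinsic(const uint4* ptr) {
    return __ldcg(ptr);
}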
// Per-SM results: cycles per iteration and verification pass/fail
__device__ unsigned long long g_cycles[NUM_ITERS][NUM_SMS];
__device__ int g_passed[NUM_SMS];
__global__ void all_gather_kernel() {
    cg::grid_group grid = cg::this_grid();
    const int bid = blockIdx.x;
    const int tid = threadIdx.x;
    // Shared memory to gather all data (170 * 16 bytes = 2720 bytes)
    __shared__ uint4 smem[NUM_SMS];
    int total_errors = 0;
    // Initialize all g_buffer slots with sentinel value before main loop
    if (tid == 0) {
        for (int i = 0; i < NUM_ITERS; i++) {
            store_uint4_volatile(&g_buffer[i][bid], {999999, 999999, 999999, 999999});
        }
    }
    grid.sync();
    for (int iter = 0; iter < NUM_ITERS; iter++) {
        unsigned long long t_start = clock64();
        // Step 1: Thread 0 writes this SM's uint4 (16 bytes) to the global buffer
        //         with a volatile (cache-bypassing) store
        // Data encoding: x = iter*1000 + bid*100, y = +1, z = +2, w = +3
        if (tid == 0) {
            uint4 data;
            data.x = iter * 1000 + bid * 100 + 0;
            data.y = iter * 1000 + bid * 100 + 1;
            data.z = iter * 1000 + bid * 100 + 2;
            data.w = iter * 1000 + bid * 100 + 3;
            store_uint4_volatile(&g_buffer[iter][bid], data); // Single 128-bit volatile store
        }
        // Step 2: No block- or grid-wide synchronization here; the spin-polling in
        //         Step 3 takes its place.
        // __syncthreads();
        // grid.sync();
        // Step 3: Gather all data into shared memory, spin-polling until valid data arrives.
        //         Each of threads 0..NUM_SMS-1 loads the uint4 published by one SM.
        if (tid < NUM_SMS) {
            uint4 data = load_uint4_volatile(&g_buffer[iter][tid]);
            // While the sentinel is still visible, keep polling with volatile loads
            while (data.w == 999999 || data.x == 999999) {
                data = load_uint4_volatile(&g_buffer[iter][tid]);
            }
            smem[tid] = data;
        }
        __syncthreads();
        unsigned long long t_end = clock64();
        // Record cycles for this iteration
        if (tid == 0) {
            g_cycles[iter][bid] = t_end - t_start;
        }
        __syncthreads();
        // Step 4: Verify correctness
        int errors = 0;
        for (int sm = tid; sm < NUM_SMS; sm += blockDim.x) {
            uint4 val = smem[sm];
            if (val.x != (unsigned)(iter * 1000 + sm * 100 + 0)) errors++;
            if (val.y != (unsigned)(iter * 1000 + sm * 100 + 1)) errors++;
            if (val.z != (unsigned)(iter * 1000 + sm * 100 + 2)) errors++;
            if (val.w != (unsigned)(iter * 1000 + sm * 100 + 3)) errors++;
        }
        // Reduce errors within block
        __shared__ int iter_errors;
        if (tid == 0) iter_errors = 0;
        __syncthreads();
        if (errors > 0) {
            atomicAdd(&iter_errors, errors);
        }
        __syncthreads();
        total_errors += iter_errors;
        // Sync before next iteration
        grid.sync();
    }
    // Thread 0 records final pass/fail
    if (tid == 0) {
        g_passed[bid] = (total_errors == 0) ? 1 : 0;
    }
}
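// Note (optional refinement, sketch only): clock64() reads a per-SM counter, so cycle
// counts from different SMs are not directly comparable on one timeline. If a grid-wide
// timestamp is needed, the PTX %globaltimer register (nanosecond resolution) can be read:
__device__ __forceinline__ unsigned long long read_globaltimer() {
    unsigned long long t;
    asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(t));
    return t;
}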
int main() {
    // Query device properties for cooperative launch
    int device;
    cudaGetDevice(&device);
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);
    printf("Device: %s\n", prop.name);
    printf("SMs: %d\n", prop.multiProcessorCount);
    if (!prop.cooperativeLaunch) {
        printf("ERROR: Cooperative launch not supported!\n");
        return 1;
    }
    // Use 256 threads per block
    int threads_per_block = 256;
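    // Sanity check (added sketch): a cooperative launch requires every block to be
    // resident at once. Confirm that at least one block of this size fits per SM and
    // that the device actually has NUM_SMS SMs before launching.
    int max_blocks_per_sm = 0;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm,
                                                  all_gather_kernel,
                                                  threads_per_block, 0);
    if (max_blocks_per_sm < 1 || prop.multiProcessorCount < NUM_SMS) {
        printf("WARNING: grid of %d blocks x %d threads may not be co-resident "
               "(max blocks/SM = %d, SMs = %d)\n",
               NUM_SMS, threads_per_block, max_blocks_per_sm, prop.multiProcessorCount);
    }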
    // Launch with cooperative groups. Check the launch call itself: errors such as
    // cudaErrorCooperativeLaunchTooLarge are reported here, not by the later sync.
    void* kernel_args[] = {};
    cudaError_t err = cudaLaunchCooperativeKernel(
        (void*)all_gather_kernel,
        dim3(NUM_SMS),
        dim3(threads_per_block),
        kernel_args
    );
    if (err != cudaSuccess) {
        printf("Launch error: %s\n", cudaGetErrorString(err));
        return 1;
    }
    err = cudaDeviceSynchronize();
    if (err != cudaSuccess) {
        printf("Kernel error: %s\n", cudaGetErrorString(err));
        return 1;
    }
    // Copy results back
    unsigned long long h_cycles[NUM_ITERS][NUM_SMS];
    int h_passed[NUM_SMS];
    cudaMemcpyFromSymbol(h_cycles, g_cycles, sizeof(h_cycles));
    cudaMemcpyFromSymbol(h_passed, g_passed, sizeof(h_passed));
    // Check all passed
    int all_passed = 1;
    for (int i = 0; i < NUM_SMS; i++) {
        if (!h_passed[i]) all_passed = 0;
    }
    // Analyze results per iteration
    printf("\n=== All-Gather Results (Persistent Kernel) ===\n");
    printf("Blocks: %d, Threads/block: %d, Iterations: %d\n",
           NUM_SMS, threads_per_block, NUM_ITERS);
    printf("Data per SM: %d bytes\n", BYTES_PER_SM);
    printf("Total data gathered per iteration: %d bytes\n", NUM_SMS * BYTES_PER_SM);
    printf("\n");
    double clock_ghz = 2.82; // RTX 5090 boost clock
    // Compute per-iteration max cycles
    unsigned long long iter_max[NUM_ITERS];
    unsigned long long overall_min = ULLONG_MAX, overall_max = 0;
    unsigned long long sum = 0;
    for (int iter = 0; iter < NUM_ITERS; iter++) {
        unsigned long long max_cycles = 0;
        for (int sm = 0; sm < NUM_SMS; sm++) {
            if (h_cycles[iter][sm] > max_cycles) {
                max_cycles = h_cycles[iter][sm];
            }
        }
        iter_max[iter] = max_cycles;
        if (max_cycles < overall_min) overall_min = max_cycles;
        if (max_cycles > overall_max) overall_max = max_cycles;
        sum += max_cycles;
    }
    double avg = (double)sum / NUM_ITERS;
    printf("Per-iteration max cycles:\n");
    printf("Iter | Max Cycles | Time (us)\n");
    printf("-----|------------|----------\n");
    for (int i = 0; i < NUM_ITERS; i++) {
        printf("%4d | %10llu | %9.3f\n", i, iter_max[i], iter_max[i] / (clock_ghz * 1000));
    }
    printf("\nSummary (excluding iter 0):\n");
    unsigned long long sum_ex0 = sum - iter_max[0];
    unsigned long long min_ex0 = ULLONG_MAX, max_ex0 = 0;
    for (int i = 1; i < NUM_ITERS; i++) {
        if (iter_max[i] < min_ex0) min_ex0 = iter_max[i];
        if (iter_max[i] > max_ex0) max_ex0 = iter_max[i];
    }
    double avg_ex0 = (double)sum_ex0 / (NUM_ITERS - 1);
    printf(" Min: %llu cycles (%.3f us)\n", min_ex0, min_ex0 / (clock_ghz * 1000));
    printf(" Max: %llu cycles (%.3f us)\n", max_ex0, max_ex0 / (clock_ghz * 1000));
    printf(" Avg: %.1f cycles (%.3f us)\n", avg_ex0, avg_ex0 / (clock_ghz * 1000));
    printf("\nVerification: %s\n", all_passed ? "PASSED" : "FAILED");
    return all_passed ? 0 : 1;
}
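/*
Build/run sketch (assumed toolchain; adjust -arch to your GPU, or use -arch=native
on CUDA 11.6+ to target the installed card, e.g. an RTX 5090):
    nvcc -O3 -arch=native all_gather.cu -o all_gather
    ./all_gather
*/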