@youkaichao
Last active September 18, 2025 14:02
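Query the number of streaming multiprocessors (SMs) on each visible CUDA device from PyTorch by JIT-compiling a tiny inline C++ extension with torch.utils.cpp_extension.load_inline.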
import torch
from torch.utils.cpp_extension import load_inline
src = r"""
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
// Return SM count for a specific device (or current device if device_index < 0)
int64_t num_sms(int64_t device_index = -1) {
if (!at::cuda::is_available()) return 0;
int dev = (device_index >= 0) ? static_cast<int>(device_index)
: at::cuda::current_device();
// Use LibTorch helpers to fetch cudaDeviceProp
const cudaDeviceProp* prop =
(dev == at::cuda::current_device())
? at::cuda::getCurrentDeviceProperties()
: at::cuda::getDeviceProperties(dev);
return prop ? static_cast<int64_t>(prop->multiProcessorCount) : 0;
}
// Return SM counts for all visible CUDA devices
std::vector<int64_t> all_sms() {
std::vector<int64_t> out;
if (!at::cuda::is_available()) return out;
int n = static_cast<int>(at::cuda::getNumGPUs());
out.reserve(n);
for (int d = 0; d < n; ++d) {
const auto* p = at::cuda::getDeviceProperties(d);
out.push_back(p ? static_cast<int64_t>(p->multiProcessorCount) : 0);
}
return out;
}
"""
smext = load_inline(
    name="smcount_ext",
    cpp_sources=[src],
    functions=["num_sms", "all_sms"],
    with_cuda=True,  # link against CUDA so the ATen CUDA helpers resolve
    extra_cflags=["-O3"],
    extra_cuda_cflags=["-O3"],
    verbose=True,
)
print("Current device SMs:", smext.num_sms(0))
print("All device SMs :", smext.all_sms())