import torch
from torch.utils.cpp_extension import load_inline

src = r"""
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <vector>

// Return the SM count for a specific device (or the current device if device_index < 0).
int64_t num_sms(int64_t device_index = -1) {
  if (!at::cuda::is_available()) return 0;
  int dev = (device_index >= 0) ? static_cast<int>(device_index)
                                : at::cuda::current_device();
  // Use LibTorch helpers to fetch the cudaDeviceProp for the chosen device.
  const cudaDeviceProp* prop =
      (dev == at::cuda::current_device())
          ? at::cuda::getCurrentDeviceProperties()
          : at::cuda::getDeviceProperties(dev);
  return prop ? static_cast<int64_t>(prop->multiProcessorCount) : 0;
}

// Return the SM counts for all visible CUDA devices.
std::vector<int64_t> all_sms() {
  std::vector<int64_t> out;
  if (!at::cuda::is_available()) return out;
  int n = static_cast<int>(at::cuda::getNumGPUs());
  out.reserve(n);
  for (int d = 0; d < n; ++d) {
    const auto* p = at::cuda::getDeviceProperties(d);
    out.push_back(p ? static_cast<int64_t>(p->multiProcessorCount) : 0);
  }
  return out;
}
"""

# JIT-compile the inline C++ source against the PyTorch/CUDA headers and
# expose the two functions to Python.
smext = load_inline(
    name="smcount_ext",
    cpp_sources=[src],
    functions=["num_sms", "all_sms"],
    with_cuda=True,
    extra_cflags=["-O3"],
    extra_cuda_cflags=["-O3"],
    verbose=True,
)

print("Current device SMs:", smext.num_sms(-1))
print("All device SMs    :", smext.all_sms())