Last active
January 31, 2026 15:54
-
-
Save Jokeren/2072e88f59b93460ed2a6b594c5cc1aa to your computer and use it in GitHub Desktop.
64-bit layout conversion
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| from __future__ import annotations | |
| import argparse | |
| import hashlib | |
| import os | |
| import shutil | |
| import subprocess | |
| import sys | |
| import sysconfig | |
| from pathlib import Path | |
| from typing import Iterable | |
| import ctypes | |
# Directory containing this script; anchors the in-repo Triton checkout,
# the C++ extension source, and the build cache directory.
REPO_ROOT = Path(__file__).resolve().parent
def _maybe_prepend_local_triton() -> None:
    """Give the in-repo Triton checkout (``<repo>/python``) priority on sys.path."""
    local_pkg_dir = REPO_ROOT / "python"
    if not local_pkg_dir.exists():
        return
    sys.path.insert(0, str(local_pkg_dir))
| def _find_cuda_home() -> Path: | |
| env = os.environ | |
| for k in ("CUDA_HOME", "CUDA_PATH"): | |
| if k in env: | |
| p = Path(env[k]) | |
| if (p / "include" / "cuda.h").exists(): | |
| return p | |
| nvcc = shutil.which("nvcc") | |
| if nvcc: | |
| p = Path(nvcc).resolve() | |
| cuda_home = p.parent.parent | |
| if (cuda_home / "include" / "cuda.h").exists(): | |
| return cuda_home | |
| for p in (Path("/usr/local/cuda"), Path("/opt/cuda")): | |
| if (p / "include" / "cuda.h").exists(): | |
| return p | |
| raise RuntimeError( | |
| "Could not find CUDA Toolkit (cuda.h). Set CUDA_HOME to your CUDA install." | |
| ) | |
| def _ext_suffix() -> str: | |
| suffix = sysconfig.get_config_var("EXT_SUFFIX") | |
| if not suffix: | |
| raise RuntimeError("Python EXT_SUFFIX is not set; cannot build extension.") | |
| return str(suffix) | |
| def _run(cmd: list[str]) -> None: | |
| subprocess.check_call(cmd) | |
def _build_and_import_ext() -> object:
    """Compile tensor_metric_ext.cpp into a CPython extension and import it.

    Builds are cached under ``.tensor_metric_repro_build/<key>`` where <key>
    hashes the source text plus the Python/CUDA include configuration, so a
    change to any of those inputs triggers a fresh build.

    Returns:
        The imported extension module object.

    Raises:
        RuntimeError: if the C++ source, a C++ compiler, or CUDA cannot be
            found, or if the built shared object cannot be loaded.
    """
    import importlib.util
    src_path = REPO_ROOT / "tensor_metric_ext.cpp"
    if not src_path.exists():
        raise RuntimeError(f"Missing {src_path}")
    cuda_home = _find_cuda_home()
    py_include = Path(sysconfig.get_paths()["include"])
    # Cache key covers everything that could change the compiled artifact.
    key_material = "\n".join(
        [
            str(src_path.resolve()),
            src_path.read_text(encoding="utf-8"),
            sys.version,
            sysconfig.get_platform(),
            str(py_include.resolve()),
            str((cuda_home / "include").resolve()),
        ]
    ).encode("utf-8")
    build_key = hashlib.sha256(key_material).hexdigest()[:16]
    build_dir = REPO_ROOT / ".tensor_metric_repro_build" / build_key
    build_dir.mkdir(parents=True, exist_ok=True)
    so_path = build_dir / f"tensor_metric_ext{_ext_suffix()}"
    if not so_path.exists():
        # Prefer an explicit $CXX, then whatever C++ compiler is on PATH.
        cxx = os.environ.get("CXX") or shutil.which("c++") or shutil.which("g++") or shutil.which("clang++")
        if not cxx:
            raise RuntimeError("Could not find a C++ compiler (set CXX).")
        # Candidate -L directories for the CUDA driver library; stub dirs are
        # included for toolkit-only machines without an installed driver.
        lib_dirs: list[Path] = []
        for cand in (
            cuda_home / "lib64",
            cuda_home / "lib64" / "stubs",
            cuda_home / "lib",
            cuda_home / "lib" / "stubs",
        ):
            if cand.exists():
                lib_dirs.append(cand)
        cmd = [
            cxx,
            "-O3",
            "-shared",
            "-fPIC",
            "-std=c++17",
            str(src_path),
            "-o",
            str(so_path),
            f"-I{py_include}",
            f"-I{cuda_home / 'include'}",
        ]
        for d in lib_dirs:
            cmd.append(f"-L{d}")
        cmd.append("-lcuda")
        _run(cmd)
    # Load the cached (or freshly built) shared object as a module.
    spec = importlib.util.spec_from_file_location("tensor_metric_ext", so_path)
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Failed to load extension from {so_path}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod
| def _u64_view(ext: object, buf_capsule: object) -> ctypes.Array: | |
| host_ptr = int(ext.buf_host_ptr(buf_capsule)) | |
| nelems = int(ext.buf_nelems(buf_capsule)) | |
| arr_ty = ctypes.c_uint64 * nelems | |
| return arr_ty.from_address(host_ptr) | |
| def _expected_after_one_write( | |
| *, | |
| ring_size_words: int, | |
| sentinel: int, | |
| initial_offset: int, | |
| metric_id: int, | |
| metric_value_start: int, | |
| metric_value_size: int, | |
| ) -> tuple[list[int], int]: | |
| ring = [sentinel] * ring_size_words | |
| off = initial_offset % ring_size_words | |
| ring[off] = metric_id & 0xFFFFFFFFFFFFFFFF | |
| off = (off + 1) % ring_size_words | |
| for i in range(metric_value_size): | |
| ring[(off + i) % ring_size_words] = (metric_value_start + i) & 0xFFFFFFFFFFFFFFFF | |
| new_off = (off + metric_value_size) % ring_size_words | |
| return ring, new_off | |
def _verify_one_case(
    *,
    ext: object,
    kernel: int,
    stream: int,
    ring_size_words: int,
    metric_value_size: int,
    initial_offset: int,
    metric_id: int,
) -> None:
    """Launch the kernel once on a sentinel-filled ring and check it against the CPU model.

    Args:
        ext: the built tensor_metric_ext extension module.
        kernel: CUfunction handle of the compiled Triton kernel.
        stream: CUDA stream handle (0 = default stream).
        ring_size_words: ring size in uint64 words; must be > 0.
        metric_value_size: payload length in uint64 words; must be >= 0.
        initial_offset: starting ring offset stored before the launch.
        metric_id: id word the kernel is expected to write first.

    Raises:
        ValueError: on invalid sizes.
        AssertionError: if the ring contents or final offset disagree with
            ``_expected_after_one_write``.
    """
    if ring_size_words <= 0:
        raise ValueError("ring_size_words must be > 0")
    if metric_value_size < 0:
        raise ValueError("metric_value_size must be >= 0")
    # Distinctive fill so untouched words are detectable.
    sentinel = 0xDEADBEEFDEADBEEF
    ring = ext.alloc_pinned_host_u64(ring_size_words, zero=False)
    ext.fill_u64(ring, sentinel)
    offset = ext.alloc_pinned_host_u64(1, zero=True)
    ext.write_u64(offset, 0, initial_offset)
    # Even if metric_value_size == 0, pass a valid pointer (it won't be read).
    metric_value_nelems = max(1, metric_value_size)
    metric_value = ext.alloc_pinned_host_u64(metric_value_nelems, zero=False)
    metric_value_start = 0xABC00000
    ext.iota_u64(metric_value, metric_value_start)
    ext.launch_tensor_metric_kernel(
        kernel=kernel,
        stream=stream,
        device_ptr=int(ext.buf_device_ptr(ring)),
        device_offset_ptr=int(ext.buf_device_ptr(offset)),
        size_words=ring_size_words,
        metric_id=metric_id,
        metric_value_ptr=int(ext.buf_device_ptr(metric_value)),
        metric_value_size=metric_value_size,
        iters=1,
    )
    # Buffers are host-mapped, so after synchronization the host view is current.
    ext.stream_synchronize(stream)
    expected_ring, expected_offset = _expected_after_one_write(
        ring_size_words=ring_size_words,
        sentinel=sentinel,
        initial_offset=initial_offset,
        metric_id=metric_id,
        metric_value_start=metric_value_start,
        metric_value_size=metric_value_size,
    )
    ring_view = _u64_view(ext, ring)
    for i in range(ring_size_words):
        got = int(ring_view[i])
        exp = int(expected_ring[i])
        if got != exp:
            raise AssertionError(
                f"ring mismatch at idx={i}: got=0x{got:016x} expected=0x{exp:016x} "
                f"(ring_size_words={ring_size_words}, metric_value_size={metric_value_size}, initial_offset={initial_offset})"
            )
    off_view = _u64_view(ext, offset)
    got_off = int(off_view[0])
    if got_off != expected_offset:
        raise AssertionError(
            f"offset mismatch: got={got_off} expected={expected_offset} "
            f"(ring_size_words={ring_size_words}, metric_value_size={metric_value_size}, initial_offset={initial_offset})"
        )
| def _bench_cpp_launch( | |
| *, | |
| ext: object, | |
| kernel: int, | |
| stream: int, | |
| ring_size_words: int, | |
| metric_value_size: int, | |
| iters: int, | |
| warmup: int, | |
| ) -> float: | |
| sentinel = 0 | |
| ring = ext.alloc_pinned_host_u64(ring_size_words, zero=True) | |
| ext.fill_u64(ring, sentinel) | |
| offset = ext.alloc_pinned_host_u64(1, zero=True) | |
| ext.write_u64(offset, 0, 0) | |
| metric_value_nelems = max(1, metric_value_size) | |
| metric_value = ext.alloc_pinned_host_u64(metric_value_nelems, zero=True) | |
| ext.iota_u64(metric_value, 0x12340000) | |
| args = dict( | |
| kernel=kernel, | |
| stream=stream, | |
| device_ptr=int(ext.buf_device_ptr(ring)), | |
| device_offset_ptr=int(ext.buf_device_ptr(offset)), | |
| size_words=ring_size_words, | |
| metric_id=1, | |
| metric_value_ptr=int(ext.buf_device_ptr(metric_value)), | |
| metric_value_size=metric_value_size, | |
| ) | |
| for _ in range(max(0, warmup)): | |
| ext.launch_tensor_metric_kernel(**args, iters=1) | |
| ext.stream_synchronize(stream) | |
| start = int(ext.event_create()) | |
| end = int(ext.event_create()) | |
| try: | |
| ext.event_record(start, stream) | |
| ext.launch_tensor_metric_kernel(**args, iters=iters) | |
| ext.event_record(end, stream) | |
| ext.event_synchronize(end) | |
| ms = float(ext.event_elapsed_ms(start, end)) | |
| finally: | |
| ext.event_destroy(start) | |
| ext.event_destroy(end) | |
| if iters <= 0: | |
| raise ValueError("iters must be > 0") | |
| return (ms * 1e3) / iters | |
| def _iter_cases() -> Iterable[tuple[int, int, int]]: | |
| # (ring_size_words, metric_value_size, initial_offset) | |
| yield (256, 3, 0) | |
| yield (256, 127, 11) | |
| yield (257, 129, 255) | |
| yield (512, 0, 42) # empty metric value | |
def main() -> None:
    """Build the extension, JIT-compile the Triton kernel, verify, then benchmark."""
    parser = argparse.ArgumentParser(description="Reproducer for tensor_metric_kernel correctness/perf.")
    parser.add_argument("--no-verify", action="store_true", help="Skip correctness checks.")
    parser.add_argument("--ring-size", type=int, default=1 << 12, help="Ring size in uint64 elements.")
    parser.add_argument("--value-size", type=int, default=256, help="Metric value size in uint64 elements.")
    parser.add_argument("--iters", type=int, default=10000, help="Benchmark iterations (C++ launches).")
    parser.add_argument("--warmup", type=int, default=100, help="Warmup launches before benchmark.")
    parser.add_argument("--stream", type=int, default=0, help="CUDA stream handle (0 = default stream).")
    parser.add_argument(
        "--use-torch-stream",
        action="store_true",
        help="Use torch.cuda.current_stream() as the stream handle (overrides --stream).",
    )
    args = parser.parse_args()
    # Prefer the in-repo Triton checkout over any installed copy.
    _maybe_prepend_local_triton()
    try:
        import triton
        import triton.language as tl
        from triton import MockTensor
    except Exception as e:  # noqa: BLE001
        raise RuntimeError(
            "Failed to import triton. Run this in an environment where Triton is built/installed "
            "(e.g. `pip install -e .` from this repo, or set PYTHONPATH appropriately)."
        ) from e
    ext = _build_and_import_ext()
    if args.use_torch_stream:
        try:
            import torch
        except Exception as e:  # noqa: BLE001
            raise RuntimeError("--use-torch-stream requires torch to be importable.") from e
        args.stream = int(torch.cuda.current_stream().cuda_stream)

    # Ring-buffer writer: stores one metric_id word, then metric_value_size
    # payload words, wrapping modulo `size`; mirrors _expected_after_one_write.
    @triton.jit
    def tensor_metric_kernel(
        device_ptr,
        device_offset_ptr,
        size: tl.uint64,
        metric_id: tl.uint64,
        metric_value_ptr,
        metric_value_size: tl.uint64,
    ):
        BLOCK_SIZE: tl.constexpr = 256
        device_offset = tl.load(device_offset_ptr)
        tl.store(device_ptr + device_offset, metric_id)
        device_offset = (device_offset + 1) % size
        num_iters = tl.cdiv(metric_value_size, BLOCK_SIZE)
        offsets = tl.arange(0, BLOCK_SIZE)
        for i in tl.range(0, num_iters):
            cur_offsets = offsets + i * BLOCK_SIZE
            mask = cur_offsets < metric_value_size
            metric_value = tl.load(metric_value_ptr + cur_offsets, mask=mask)
            tl.store(device_ptr + (device_offset + cur_offsets) % size, metric_value, mask=mask)
            tl.debug_barrier()
        device_offset = (device_offset + metric_value_size) % size
        tl.store(device_offset_ptr, device_offset)

    # Compile once, then launch via cuLaunchKernel in the C++ extension.
    mock_ptr = MockTensor(tl.uint64)
    compiled = tensor_metric_kernel.warmup(
        mock_ptr,
        mock_ptr,
        1024,
        1,
        mock_ptr,
        1,
        grid=(1,),
        num_warps=4,
    )
    compiled._init_handles()
    kernel = int(compiled.function)
    if not args.no_verify:
        for ring_size_words, metric_value_size, initial_offset in _iter_cases():
            _verify_one_case(
                ext=ext,
                kernel=kernel,
                stream=args.stream,
                ring_size_words=ring_size_words,
                metric_value_size=metric_value_size,
                initial_offset=initial_offset,
                metric_id=0x4242,
            )
        print("correctness: OK")
    us_per = _bench_cpp_launch(
        ext=ext,
        kernel=kernel,
        stream=args.stream,
        ring_size_words=args.ring_size,
        metric_value_size=args.value_size,
        iters=args.iters,
        warmup=args.warmup,
    )
    # One launch writes the metric-id word plus the payload words, 8 bytes each.
    words = 1 + max(0, args.value_size)
    bytes_per_launch = words * 8
    print(
        f"bench: {us_per:.3f} us/launch "
        f"({bytes_per_launch} bytes written, ring_size_words={args.ring_size}, value_size={args.value_size}, iters={args.iters})"
    )
# Allow importing this file as a module without running the reproducer.
if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #define PY_SSIZE_T_CLEAN | |
| #include <Python.h> | |
| #include <cstddef> | |
| #include <cstdint> | |
| #include <cstring> | |
| #include <mutex> | |
| #include <string> | |
| #if __has_include(<cuda.h>) | |
| #include <cuda.h> | |
| #else | |
| #error "cuda.h not found. Please set CUDA_HOME to a CUDA Toolkit install." | |
| #endif | |
namespace {
// Name tag used to validate PyCapsule objects created by this module.
constexpr const char *kPinnedBufferCapsuleName =
    "tensor_metric_ext.PinnedHostBufferU64";
// One pinned (page-locked), device-mapped host allocation of uint64 words.
struct PinnedHostBufferU64 {
  void *host_ptr = nullptr;    // host pointer returned by cuMemHostAlloc
  CUdeviceptr device_ptr = 0;  // device-visible alias from cuMemHostGetDevicePointer
  uint64_t nelems = 0;         // element count (uint64 words)
};
// Calls cuInit(0) exactly once per process (thread-safe via std::call_once)
// and reports the cached result on every call. On failure, writes a
// human-readable description into *error (if non-null) and returns false.
bool ensure_driver_initialized(std::string *error) {
  static std::once_flag init_flag;
  static CUresult init_res = CUDA_ERROR_UNKNOWN;
  std::call_once(init_flag, []() { init_res = cuInit(0); });
  if (init_res != CUDA_SUCCESS) {
    const char *name = nullptr;
    const char *msg = nullptr;
    // These may leave name/msg null for unrecognized error codes.
    cuGetErrorName(init_res, &name);
    cuGetErrorString(init_res, &msg);
    if (error) {
      *error = "cuInit failed: ";
      *error += (name ? name : "UNKNOWN");
      *error += ": ";
      *error += (msg ? msg : "no message");
    }
    return false;
  }
  return true;
}
| PyObject *raise_cuda(const char *what, CUresult res) { | |
| const char *name = nullptr; | |
| const char *msg = nullptr; | |
| cuGetErrorName(res, &name); | |
| cuGetErrorString(res, &msg); | |
| return PyErr_Format(PyExc_RuntimeError, "%s: %s: %s", what, | |
| (name ? name : "UNKNOWN"), | |
| (msg ? msg : "no message")); | |
| } | |
| PinnedHostBufferU64 *get_buf(PyObject *capsule) { | |
| return static_cast<PinnedHostBufferU64 *>( | |
| PyCapsule_GetPointer(capsule, kPinnedBufferCapsuleName)); | |
| } | |
// PyCapsule destructor: frees the pinned host allocation (and its metadata)
// when the Python capsule is garbage-collected.
void capsule_destructor(PyObject *capsule) {
  auto *buf = get_buf(capsule);
  if (!buf) {
    // Not our capsule (or already invalidated); nothing we can safely free.
    return;
  }
  if (buf->host_ptr) {
    // Best-effort cleanup; don't throw from destructor.
    (void)cuMemFreeHost(buf->host_ptr);
  }
  delete buf;
}
// alloc_pinned_host_u64(nelems, zero=True, write_combined=False) -> capsule
// Allocates `nelems` uint64 words of pinned, device-mapped host memory and
// wraps it in a PyCapsule whose destructor frees the allocation. Requires a
// current CUDA context on the calling thread.
PyObject *alloc_pinned_host_u64(PyObject *, PyObject *args, PyObject *kwargs) {
  unsigned long long nelems_ull = 0;
  int zero = 1;            // zero-fill the allocation by default
  int write_combined = 0;  // opt-in CU_MEMHOSTALLOC_WRITECOMBINED
  static const char *kwnames[] = {"nelems", "zero", "write_combined", nullptr};
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "K|pp", const_cast<char **>(kwnames),
                                   &nelems_ull, &zero, &write_combined)) {
    return nullptr;
  }
  if (nelems_ull == 0) {
    return PyErr_Format(PyExc_ValueError, "nelems must be > 0");
  }
  // Guard the byte-count multiplication below against size_t overflow.
  if (nelems_ull > (static_cast<unsigned long long>(SIZE_MAX) / sizeof(uint64_t))) {
    return PyErr_Format(PyExc_OverflowError, "nelems too large");
  }
  std::string init_err;
  if (!ensure_driver_initialized(&init_err)) {
    return PyErr_Format(PyExc_RuntimeError, "%s", init_err.c_str());
  }
  // A context must already exist; this module never creates one itself.
  CUcontext ctx = nullptr;
  CUresult res = cuCtxGetCurrent(&ctx);
  if (res != CUDA_SUCCESS) {
    return raise_cuda("cuCtxGetCurrent failed", res);
  }
  if (ctx == nullptr) {
    return PyErr_Format(
        PyExc_RuntimeError,
        "No current CUDA context on this thread. Initialize CUDA first "
        "(e.g. import torch; torch.cuda.init()).");
  }
  auto *buf = new PinnedHostBufferU64();
  buf->nelems = static_cast<uint64_t>(nelems_ull);
  const size_t nbytes = static_cast<size_t>(nelems_ull) * sizeof(uint64_t);
  // DEVICEMAP makes the buffer addressable from the device; PORTABLE makes it
  // pinned for all contexts.
  unsigned int flags = CU_MEMHOSTALLOC_PORTABLE | CU_MEMHOSTALLOC_DEVICEMAP;
  if (write_combined) {
    flags |= CU_MEMHOSTALLOC_WRITECOMBINED;
  }
  res = cuMemHostAlloc(&buf->host_ptr, nbytes, flags);
  if (res != CUDA_SUCCESS) {
    delete buf;
    return raise_cuda("cuMemHostAlloc failed", res);
  }
  res = cuMemHostGetDevicePointer(&buf->device_ptr, buf->host_ptr, 0);
  if (res != CUDA_SUCCESS) {
    (void)cuMemFreeHost(buf->host_ptr);
    delete buf;
    return raise_cuda("cuMemHostGetDevicePointer failed", res);
  }
  if (zero) {
    std::memset(buf->host_ptr, 0, nbytes);
  }
  return PyCapsule_New(buf, kPinnedBufferCapsuleName, capsule_destructor);
}
| PyObject *buf_host_ptr(PyObject *, PyObject *args) { | |
| PyObject *capsule = nullptr; | |
| if (!PyArg_ParseTuple(args, "O", &capsule)) { | |
| return nullptr; | |
| } | |
| auto *buf = get_buf(capsule); | |
| if (!buf) { | |
| return nullptr; | |
| } | |
| return PyLong_FromUnsignedLongLong( | |
| static_cast<unsigned long long>(reinterpret_cast<uintptr_t>(buf->host_ptr))); | |
| } | |
| PyObject *buf_device_ptr(PyObject *, PyObject *args) { | |
| PyObject *capsule = nullptr; | |
| if (!PyArg_ParseTuple(args, "O", &capsule)) { | |
| return nullptr; | |
| } | |
| auto *buf = get_buf(capsule); | |
| if (!buf) { | |
| return nullptr; | |
| } | |
| return PyLong_FromUnsignedLongLong( | |
| static_cast<unsigned long long>(static_cast<uintptr_t>(buf->device_ptr))); | |
| } | |
| PyObject *buf_nelems(PyObject *, PyObject *args) { | |
| PyObject *capsule = nullptr; | |
| if (!PyArg_ParseTuple(args, "O", &capsule)) { | |
| return nullptr; | |
| } | |
| auto *buf = get_buf(capsule); | |
| if (!buf) { | |
| return nullptr; | |
| } | |
| return PyLong_FromUnsignedLongLong(static_cast<unsigned long long>(buf->nelems)); | |
| } | |
| PyObject *fill_u64(PyObject *, PyObject *args) { | |
| PyObject *capsule = nullptr; | |
| unsigned long long value = 0; | |
| if (!PyArg_ParseTuple(args, "OK", &capsule, &value)) { | |
| return nullptr; | |
| } | |
| auto *buf = get_buf(capsule); | |
| if (!buf) { | |
| return nullptr; | |
| } | |
| auto *p = reinterpret_cast<uint64_t *>(buf->host_ptr); | |
| for (uint64_t i = 0; i < buf->nelems; ++i) { | |
| p[i] = static_cast<uint64_t>(value); | |
| } | |
| Py_RETURN_NONE; | |
| } | |
| PyObject *iota_u64(PyObject *, PyObject *args) { | |
| PyObject *capsule = nullptr; | |
| unsigned long long start = 0; | |
| if (!PyArg_ParseTuple(args, "OK", &capsule, &start)) { | |
| return nullptr; | |
| } | |
| auto *buf = get_buf(capsule); | |
| if (!buf) { | |
| return nullptr; | |
| } | |
| auto *p = reinterpret_cast<uint64_t *>(buf->host_ptr); | |
| uint64_t v = static_cast<uint64_t>(start); | |
| for (uint64_t i = 0; i < buf->nelems; ++i) { | |
| p[i] = v++; | |
| } | |
| Py_RETURN_NONE; | |
| } | |
| PyObject *read_u64(PyObject *, PyObject *args) { | |
| PyObject *capsule = nullptr; | |
| unsigned long long idx = 0; | |
| if (!PyArg_ParseTuple(args, "OK", &capsule, &idx)) { | |
| return nullptr; | |
| } | |
| auto *buf = get_buf(capsule); | |
| if (!buf) { | |
| return nullptr; | |
| } | |
| if (idx >= buf->nelems) { | |
| return PyErr_Format(PyExc_IndexError, "index out of range"); | |
| } | |
| const auto *p = reinterpret_cast<const uint64_t *>(buf->host_ptr); | |
| return PyLong_FromUnsignedLongLong(static_cast<unsigned long long>(p[idx])); | |
| } | |
| PyObject *write_u64(PyObject *, PyObject *args) { | |
| PyObject *capsule = nullptr; | |
| unsigned long long idx = 0; | |
| unsigned long long value = 0; | |
| if (!PyArg_ParseTuple(args, "OKK", &capsule, &idx, &value)) { | |
| return nullptr; | |
| } | |
| auto *buf = get_buf(capsule); | |
| if (!buf) { | |
| return nullptr; | |
| } | |
| if (idx >= buf->nelems) { | |
| return PyErr_Format(PyExc_IndexError, "index out of range"); | |
| } | |
| auto *p = reinterpret_cast<uint64_t *>(buf->host_ptr); | |
| p[idx] = static_cast<uint64_t>(value); | |
| Py_RETURN_NONE; | |
| } | |
// stream_synchronize(stream) -> None
// Blocks until all work queued on the given stream handle has completed.
PyObject *stream_synchronize(PyObject *, PyObject *args) {
  unsigned long long stream_ull = 0;
  if (!PyArg_ParseTuple(args, "K", &stream_ull)) {
    return nullptr;
  }
  std::string init_err;
  if (!ensure_driver_initialized(&init_err)) {
    return PyErr_Format(PyExc_RuntimeError, "%s", init_err.c_str());
  }
  CUresult res =
      cuStreamSynchronize(reinterpret_cast<CUstream>(static_cast<uintptr_t>(stream_ull)));
  if (res != CUDA_SUCCESS) {
    return raise_cuda("cuStreamSynchronize failed", res);
  }
  Py_RETURN_NONE;
}
// event_create(disable_timing=False) -> int
// Creates a CUDA event and returns its handle as an integer. Timing-disabled
// events cannot be used with event_elapsed_ms.
PyObject *event_create(PyObject *, PyObject *args, PyObject *kwargs) {
  int disable_timing = 0;
  static const char *kwnames[] = {"disable_timing", nullptr};
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|p", const_cast<char **>(kwnames),
                                   &disable_timing)) {
    return nullptr;
  }
  std::string init_err;
  if (!ensure_driver_initialized(&init_err)) {
    return PyErr_Format(PyExc_RuntimeError, "%s", init_err.c_str());
  }
  CUevent ev = nullptr;
  unsigned int flags = disable_timing ? CU_EVENT_DISABLE_TIMING : CU_EVENT_DEFAULT;
  CUresult res = cuEventCreate(&ev, flags);
  if (res != CUDA_SUCCESS) {
    return raise_cuda("cuEventCreate failed", res);
  }
  return PyLong_FromUnsignedLongLong(static_cast<unsigned long long>(reinterpret_cast<uintptr_t>(ev)));
}
// event_destroy(event) -> None
// Destroys an event handle previously returned by event_create.
PyObject *event_destroy(PyObject *, PyObject *args) {
  unsigned long long ev_ull = 0;
  if (!PyArg_ParseTuple(args, "K", &ev_ull)) {
    return nullptr;
  }
  std::string init_err;
  if (!ensure_driver_initialized(&init_err)) {
    return PyErr_Format(PyExc_RuntimeError, "%s", init_err.c_str());
  }
  CUevent ev = reinterpret_cast<CUevent>(static_cast<uintptr_t>(ev_ull));
  CUresult res = cuEventDestroy(ev);
  if (res != CUDA_SUCCESS) {
    return raise_cuda("cuEventDestroy failed", res);
  }
  Py_RETURN_NONE;
}
// event_record(event, stream) -> None
// Records the event on the given stream (captures the stream's current point).
PyObject *event_record(PyObject *, PyObject *args) {
  unsigned long long ev_ull = 0;
  unsigned long long stream_ull = 0;
  if (!PyArg_ParseTuple(args, "KK", &ev_ull, &stream_ull)) {
    return nullptr;
  }
  std::string init_err;
  if (!ensure_driver_initialized(&init_err)) {
    return PyErr_Format(PyExc_RuntimeError, "%s", init_err.c_str());
  }
  CUevent ev = reinterpret_cast<CUevent>(static_cast<uintptr_t>(ev_ull));
  CUstream stream = reinterpret_cast<CUstream>(static_cast<uintptr_t>(stream_ull));
  CUresult res = cuEventRecord(ev, stream);
  if (res != CUDA_SUCCESS) {
    return raise_cuda("cuEventRecord failed", res);
  }
  Py_RETURN_NONE;
}
// event_synchronize(event) -> None
// Blocks until the event's recorded work has completed.
PyObject *event_synchronize(PyObject *, PyObject *args) {
  unsigned long long ev_ull = 0;
  if (!PyArg_ParseTuple(args, "K", &ev_ull)) {
    return nullptr;
  }
  std::string init_err;
  if (!ensure_driver_initialized(&init_err)) {
    return PyErr_Format(PyExc_RuntimeError, "%s", init_err.c_str());
  }
  CUevent ev = reinterpret_cast<CUevent>(static_cast<uintptr_t>(ev_ull));
  CUresult res = cuEventSynchronize(ev);
  if (res != CUDA_SUCCESS) {
    return raise_cuda("cuEventSynchronize failed", res);
  }
  Py_RETURN_NONE;
}
// event_elapsed_ms(start, end) -> float
// Returns the elapsed GPU time between two recorded events, in milliseconds.
// Both events must have been created without CU_EVENT_DISABLE_TIMING.
PyObject *event_elapsed_ms(PyObject *, PyObject *args) {
  unsigned long long start_ull = 0;
  unsigned long long end_ull = 0;
  if (!PyArg_ParseTuple(args, "KK", &start_ull, &end_ull)) {
    return nullptr;
  }
  std::string init_err;
  if (!ensure_driver_initialized(&init_err)) {
    return PyErr_Format(PyExc_RuntimeError, "%s", init_err.c_str());
  }
  CUevent start = reinterpret_cast<CUevent>(static_cast<uintptr_t>(start_ull));
  CUevent end = reinterpret_cast<CUevent>(static_cast<uintptr_t>(end_ull));
  float ms = 0.0f;
  CUresult res = cuEventElapsedTime(&ms, start, end);
  if (res != CUDA_SUCCESS) {
    return raise_cuda("cuEventElapsedTime failed", res);
  }
  return PyFloat_FromDouble(static_cast<double>(ms));
}
// launch_tensor_metric_kernel(kernel, stream, device_ptr, device_offset_ptr,
//                             size_words, metric_id, metric_value_ptr,
//                             metric_value_size, iters=1) -> None
// Launches the Triton-compiled kernel `iters` times back-to-back on `stream`
// via cuLaunchKernel, with a 1x1x1 grid of 128 threads (matching num_warps=4
// on the Python side). iters=0 is a no-op.
PyObject *launch_tensor_metric_kernel(PyObject *, PyObject *args, PyObject *kwargs) {
  unsigned long long kernel_ull = 0;
  unsigned long long stream_ull = 0;
  unsigned long long device_ptr_ull = 0;
  unsigned long long device_offset_ptr_ull = 0;
  unsigned long long size_words_ull = 0;
  unsigned long long metric_id_ull = 0;
  unsigned long long metric_value_ptr_ull = 0;
  unsigned long long metric_value_size_ull = 0;
  unsigned long long iters_ull = 1;
  static const char *kwnames[] = {"kernel", "stream", "device_ptr", "device_offset_ptr", "size_words",
                                  "metric_id", "metric_value_ptr", "metric_value_size", "iters", nullptr};
  if (!PyArg_ParseTupleAndKeywords(
          args, kwargs, "KKKKKKKK|K", const_cast<char **>(kwnames), &kernel_ull, &stream_ull,
          &device_ptr_ull, &device_offset_ptr_ull, &size_words_ull, &metric_id_ull,
          &metric_value_ptr_ull, &metric_value_size_ull, &iters_ull)) {
    return nullptr;
  }
  if (iters_ull == 0) {
    Py_RETURN_NONE;
  }
  std::string init_err;
  if (!ensure_driver_initialized(&init_err)) {
    return PyErr_Format(PyExc_RuntimeError, "%s", init_err.c_str());
  }
  // Integer handles are round-tripped back to driver types.
  CUfunction kernel = reinterpret_cast<CUfunction>(static_cast<uintptr_t>(kernel_ull));
  CUstream stream = reinterpret_cast<CUstream>(static_cast<uintptr_t>(stream_ull));
  CUdeviceptr device_ptr = static_cast<CUdeviceptr>(static_cast<uintptr_t>(device_ptr_ull));
  CUdeviceptr device_offset_ptr = static_cast<CUdeviceptr>(static_cast<uintptr_t>(device_offset_ptr_ull));
  uint64_t size_words = static_cast<uint64_t>(size_words_ull);
  uint64_t metric_id = static_cast<uint64_t>(metric_id_ull);
  CUdeviceptr metric_value_ptr = static_cast<CUdeviceptr>(static_cast<uintptr_t>(metric_value_ptr_ull));
  uint64_t metric_value_size = static_cast<uint64_t>(metric_value_size_ull);
  // NOTE(review): the two trailing scratch pointers appear to mirror extra
  // parameters Triton appends to its generated kernel ABI — confirm against
  // the Triton version in use before changing this list.
  void *global_scratch_ptr = nullptr;
  void *profile_scratch_ptr = nullptr;
  void *kernel_params[] = {reinterpret_cast<void *>(&device_ptr),
                           reinterpret_cast<void *>(&device_offset_ptr),
                           reinterpret_cast<void *>(&size_words),
                           reinterpret_cast<void *>(&metric_id),
                           reinterpret_cast<void *>(&metric_value_ptr),
                           reinterpret_cast<void *>(&metric_value_size),
                           reinterpret_cast<void *>(&global_scratch_ptr),
                           reinterpret_cast<void *>(&profile_scratch_ptr)};
  for (unsigned long long i = 0; i < iters_ull; ++i) {
    CUresult res =
        cuLaunchKernel(kernel, 1, 1, 1, 128, 1, 1, 0, stream, kernel_params, nullptr);
    if (res != CUDA_SUCCESS) {
      return raise_cuda("cuLaunchKernel failed", res);
    }
  }
  Py_RETURN_NONE;
}
// Python-visible method table; the docstrings here define the module's API.
PyMethodDef kMethods[] = {
    {"alloc_pinned_host_u64", (PyCFunction)alloc_pinned_host_u64,
     METH_VARARGS | METH_KEYWORDS,
     "alloc_pinned_host_u64(nelems: int, zero: bool=True, write_combined: bool=False) -> capsule\n"
     "Allocates mapped, pinned host memory (uint64 elements) and returns a capsule."},
    {"buf_host_ptr", buf_host_ptr, METH_VARARGS,
     "buf_host_ptr(capsule) -> int\nReturns host pointer as integer."},
    {"buf_device_ptr", buf_device_ptr, METH_VARARGS,
     "buf_device_ptr(capsule) -> int\nReturns device-mapped pointer as integer."},
    {"buf_nelems", buf_nelems, METH_VARARGS,
     "buf_nelems(capsule) -> int\nReturns number of uint64 elements."},
    {"fill_u64", fill_u64, METH_VARARGS,
     "fill_u64(capsule, value: int) -> None\nFills the buffer (on host) with value."},
    {"iota_u64", iota_u64, METH_VARARGS,
     "iota_u64(capsule, start: int) -> None\nFills the buffer with start, start+1, ... (uint64)."},
    {"read_u64", read_u64, METH_VARARGS,
     "read_u64(capsule, idx: int) -> int\nReads one uint64 from host buffer."},
    {"write_u64", write_u64, METH_VARARGS,
     "write_u64(capsule, idx: int, value: int) -> None\nWrites one uint64 into host buffer."},
    {"stream_synchronize", stream_synchronize, METH_VARARGS,
     "stream_synchronize(stream: int) -> None\nSynchronizes a CUDA stream handle."},
    {"event_create", (PyCFunction)event_create, METH_VARARGS | METH_KEYWORDS,
     "event_create(disable_timing: bool=False) -> int\nCreates a CUDA event and returns its handle."},
    {"event_destroy", event_destroy, METH_VARARGS,
     "event_destroy(event: int) -> None\nDestroys a CUDA event."},
    {"event_record", event_record, METH_VARARGS,
     "event_record(event: int, stream: int) -> None\nRecords an event on a stream."},
    {"event_synchronize", event_synchronize, METH_VARARGS,
     "event_synchronize(event: int) -> None\nSynchronizes an event."},
    {"event_elapsed_ms", event_elapsed_ms, METH_VARARGS,
     "event_elapsed_ms(start: int, end: int) -> float\nReturns elapsed time between events (ms)."},
    {"launch_tensor_metric_kernel", (PyCFunction)launch_tensor_metric_kernel,
     METH_VARARGS | METH_KEYWORDS,
     "launch_tensor_metric_kernel(kernel: int, stream: int, device_ptr: int, device_offset_ptr: int,\n"
     "                            size_words: int, metric_id: int, metric_value_ptr: int,\n"
     "                            metric_value_size: int, iters: int=1) -> None\n"
     "Launches the Triton-compiled tensor_metric_kernel via cuLaunchKernel."},
    {nullptr, nullptr, 0, nullptr},  // sentinel terminating the table
};
// Module definition; m_size of -1 means module state lives in globals
// (single-phase initialization, no sub-interpreter support).
PyModuleDef kModule = {
    PyModuleDef_HEAD_INIT,
    "tensor_metric_ext",
    "Small CUDA driver helpers for reproducing tensor_metric_kernel behavior.",
    -1,
    kMethods,
    nullptr,
    nullptr,
    nullptr,
    nullptr,
};
} // namespace
// Module entry point invoked by CPython on `import tensor_metric_ext`.
extern "C" PyObject *PyInit_tensor_metric_ext(void) { return PyModule_Create(&kModule); }
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment