Skip to content

Instantly share code, notes, and snippets.

@Jokeren
Last active January 31, 2026 15:54
Show Gist options
  • Select an option

  • Save Jokeren/2072e88f59b93460ed2a6b594c5cc1aa to your computer and use it in GitHub Desktop.

Select an option

Save Jokeren/2072e88f59b93460ed2a6b594c5cc1aa to your computer and use it in GitHub Desktop.
64-bit layout conversion
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import hashlib
import os
import shutil
import subprocess
import sys
import sysconfig
from pathlib import Path
from typing import Iterable
import ctypes
REPO_ROOT = Path(__file__).resolve().parent
def _maybe_prepend_local_triton() -> None:
triton_src = REPO_ROOT / "python"
if triton_src.exists():
sys.path.insert(0, str(triton_src))
def _find_cuda_home() -> Path:
env = os.environ
for k in ("CUDA_HOME", "CUDA_PATH"):
if k in env:
p = Path(env[k])
if (p / "include" / "cuda.h").exists():
return p
nvcc = shutil.which("nvcc")
if nvcc:
p = Path(nvcc).resolve()
cuda_home = p.parent.parent
if (cuda_home / "include" / "cuda.h").exists():
return cuda_home
for p in (Path("/usr/local/cuda"), Path("/opt/cuda")):
if (p / "include" / "cuda.h").exists():
return p
raise RuntimeError(
"Could not find CUDA Toolkit (cuda.h). Set CUDA_HOME to your CUDA install."
)
def _ext_suffix() -> str:
suffix = sysconfig.get_config_var("EXT_SUFFIX")
if not suffix:
raise RuntimeError("Python EXT_SUFFIX is not set; cannot build extension.")
return str(suffix)
def _run(cmd: list[str]) -> None:
subprocess.check_call(cmd)
def _build_and_import_ext() -> object:
    """Compile ``tensor_metric_ext.cpp`` into an extension module and import it.

    Builds are cached under ``.tensor_metric_repro_build/<key>`` where the key
    hashes everything that could affect the output (source path and text,
    interpreter version/platform, include paths), so recompilation happens
    only when one of those changes.

    Returns:
        The imported ``tensor_metric_ext`` module object.

    Raises:
        RuntimeError: if the C++ source, a C++ compiler, or the CUDA Toolkit
            cannot be found, or the built shared object cannot be loaded.
        subprocess.CalledProcessError: if the compiler invocation fails.
    """
    import importlib.util
    src_path = REPO_ROOT / "tensor_metric_ext.cpp"
    if not src_path.exists():
        raise RuntimeError(f"Missing {src_path}")
    cuda_home = _find_cuda_home()
    py_include = Path(sysconfig.get_paths()["include"])
    # Cache key: every input that could change the build output participates.
    key_material = "\n".join(
        [
            str(src_path.resolve()),
            src_path.read_text(encoding="utf-8"),
            sys.version,
            sysconfig.get_platform(),
            str(py_include.resolve()),
            str((cuda_home / "include").resolve()),
        ]
    ).encode("utf-8")
    build_key = hashlib.sha256(key_material).hexdigest()[:16]
    build_dir = REPO_ROOT / ".tensor_metric_repro_build" / build_key
    build_dir.mkdir(parents=True, exist_ok=True)
    so_path = build_dir / f"tensor_metric_ext{_ext_suffix()}"
    if not so_path.exists():
        cxx = os.environ.get("CXX") or shutil.which("c++") or shutil.which("g++") or shutil.which("clang++")
        if not cxx:
            raise RuntimeError("Could not find a C++ compiler (set CXX).")
        # Collect whichever CUDA library directories exist; the `stubs`
        # variants presumably allow linking -lcuda without a driver
        # installed -- TODO confirm on a driverless build host.
        lib_dirs: list[Path] = []
        for cand in (
            cuda_home / "lib64",
            cuda_home / "lib64" / "stubs",
            cuda_home / "lib",
            cuda_home / "lib" / "stubs",
        ):
            if cand.exists():
                lib_dirs.append(cand)
        cmd = [
            cxx,
            "-O3",
            "-shared",
            "-fPIC",
            "-std=c++17",
            str(src_path),
            "-o",
            str(so_path),
            f"-I{py_include}",
            f"-I{cuda_home / 'include'}",
        ]
        for d in lib_dirs:
            cmd.append(f"-L{d}")
        cmd.append("-lcuda")
        _run(cmd)
    # Load the freshly-built (or cached) shared object as a module.
    spec = importlib.util.spec_from_file_location("tensor_metric_ext", so_path)
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Failed to load extension from {so_path}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod
def _u64_view(ext: object, buf_capsule: object) -> ctypes.Array:
host_ptr = int(ext.buf_host_ptr(buf_capsule))
nelems = int(ext.buf_nelems(buf_capsule))
arr_ty = ctypes.c_uint64 * nelems
return arr_ty.from_address(host_ptr)
def _expected_after_one_write(
*,
ring_size_words: int,
sentinel: int,
initial_offset: int,
metric_id: int,
metric_value_start: int,
metric_value_size: int,
) -> tuple[list[int], int]:
ring = [sentinel] * ring_size_words
off = initial_offset % ring_size_words
ring[off] = metric_id & 0xFFFFFFFFFFFFFFFF
off = (off + 1) % ring_size_words
for i in range(metric_value_size):
ring[(off + i) % ring_size_words] = (metric_value_start + i) & 0xFFFFFFFFFFFFFFFF
new_off = (off + metric_value_size) % ring_size_words
return ring, new_off
def _verify_one_case(
    *,
    ext: object,
    kernel: int,
    stream: int,
    ring_size_words: int,
    metric_value_size: int,
    initial_offset: int,
    metric_id: int,
) -> None:
    """Launch the kernel once and check the ring buffer and offset word.

    Allocates mapped pinned host buffers (ring, one-word offset, metric
    payload), launches the compiled kernel once on *stream*, synchronizes,
    then compares every ring word and the final offset against the
    pure-Python model in ``_expected_after_one_write``.

    Raises:
        ValueError: for invalid sizes.
        AssertionError: on any mismatch (message includes the index and the
            case parameters).
    """
    if ring_size_words <= 0:
        raise ValueError("ring_size_words must be > 0")
    if metric_value_size < 0:
        raise ValueError("metric_value_size must be >= 0")
    # Recognizable fill pattern so untouched words are easy to spot in dumps.
    sentinel = 0xDEADBEEFDEADBEEF
    ring = ext.alloc_pinned_host_u64(ring_size_words, zero=False)
    ext.fill_u64(ring, sentinel)
    offset = ext.alloc_pinned_host_u64(1, zero=True)
    ext.write_u64(offset, 0, initial_offset)
    # Even if metric_value_size == 0, pass a valid pointer (it won't be read).
    metric_value_nelems = max(1, metric_value_size)
    metric_value = ext.alloc_pinned_host_u64(metric_value_nelems, zero=False)
    metric_value_start = 0xABC00000
    ext.iota_u64(metric_value, metric_value_start)
    ext.launch_tensor_metric_kernel(
        kernel=kernel,
        stream=stream,
        device_ptr=int(ext.buf_device_ptr(ring)),
        device_offset_ptr=int(ext.buf_device_ptr(offset)),
        size_words=ring_size_words,
        metric_id=metric_id,
        metric_value_ptr=int(ext.buf_device_ptr(metric_value)),
        metric_value_size=metric_value_size,
        iters=1,
    )
    # Wait for the launch to finish before reading the host-mapped buffers.
    ext.stream_synchronize(stream)
    expected_ring, expected_offset = _expected_after_one_write(
        ring_size_words=ring_size_words,
        sentinel=sentinel,
        initial_offset=initial_offset,
        metric_id=metric_id,
        metric_value_start=metric_value_start,
        metric_value_size=metric_value_size,
    )
    ring_view = _u64_view(ext, ring)
    for i in range(ring_size_words):
        got = int(ring_view[i])
        exp = int(expected_ring[i])
        if got != exp:
            raise AssertionError(
                f"ring mismatch at idx={i}: got=0x{got:016x} expected=0x{exp:016x} "
                f"(ring_size_words={ring_size_words}, metric_value_size={metric_value_size}, initial_offset={initial_offset})"
            )
    off_view = _u64_view(ext, offset)
    got_off = int(off_view[0])
    if got_off != expected_offset:
        raise AssertionError(
            f"offset mismatch: got={got_off} expected={expected_offset} "
            f"(ring_size_words={ring_size_words}, metric_value_size={metric_value_size}, initial_offset={initial_offset})"
        )
def _bench_cpp_launch(
*,
ext: object,
kernel: int,
stream: int,
ring_size_words: int,
metric_value_size: int,
iters: int,
warmup: int,
) -> float:
sentinel = 0
ring = ext.alloc_pinned_host_u64(ring_size_words, zero=True)
ext.fill_u64(ring, sentinel)
offset = ext.alloc_pinned_host_u64(1, zero=True)
ext.write_u64(offset, 0, 0)
metric_value_nelems = max(1, metric_value_size)
metric_value = ext.alloc_pinned_host_u64(metric_value_nelems, zero=True)
ext.iota_u64(metric_value, 0x12340000)
args = dict(
kernel=kernel,
stream=stream,
device_ptr=int(ext.buf_device_ptr(ring)),
device_offset_ptr=int(ext.buf_device_ptr(offset)),
size_words=ring_size_words,
metric_id=1,
metric_value_ptr=int(ext.buf_device_ptr(metric_value)),
metric_value_size=metric_value_size,
)
for _ in range(max(0, warmup)):
ext.launch_tensor_metric_kernel(**args, iters=1)
ext.stream_synchronize(stream)
start = int(ext.event_create())
end = int(ext.event_create())
try:
ext.event_record(start, stream)
ext.launch_tensor_metric_kernel(**args, iters=iters)
ext.event_record(end, stream)
ext.event_synchronize(end)
ms = float(ext.event_elapsed_ms(start, end))
finally:
ext.event_destroy(start)
ext.event_destroy(end)
if iters <= 0:
raise ValueError("iters must be > 0")
return (ms * 1e3) / iters
def _iter_cases() -> Iterable[tuple[int, int, int]]:
# (ring_size_words, metric_value_size, initial_offset)
yield (256, 3, 0)
yield (256, 127, 11)
yield (257, 129, 255)
yield (512, 0, 42) # empty metric value
def main() -> None:
    """Entry point: build the extension, verify kernel correctness, then benchmark.

    Steps: parse CLI args, make the in-repo Triton importable, compile the
    C++ helper extension, JIT-compile the Triton kernel once, optionally run
    the correctness cases from ``_iter_cases``, and finally report the
    per-launch latency from ``_bench_cpp_launch``.
    """
    parser = argparse.ArgumentParser(description="Reproducer for tensor_metric_kernel correctness/perf.")
    parser.add_argument("--no-verify", action="store_true", help="Skip correctness checks.")
    parser.add_argument("--ring-size", type=int, default=1 << 12, help="Ring size in uint64 elements.")
    parser.add_argument("--value-size", type=int, default=256, help="Metric value size in uint64 elements.")
    parser.add_argument("--iters", type=int, default=10000, help="Benchmark iterations (C++ launches).")
    parser.add_argument("--warmup", type=int, default=100, help="Warmup launches before benchmark.")
    parser.add_argument("--stream", type=int, default=0, help="CUDA stream handle (0 = default stream).")
    parser.add_argument(
        "--use-torch-stream",
        action="store_true",
        help="Use torch.cuda.current_stream() as the stream handle (overrides --stream).",
    )
    args = parser.parse_args()
    _maybe_prepend_local_triton()
    try:
        import triton
        import triton.language as tl
        from triton import MockTensor
    except Exception as e:  # noqa: BLE001
        raise RuntimeError(
            "Failed to import triton. Run this in an environment where Triton is built/installed "
            "(e.g. `pip install -e .` from this repo, or set PYTHONPATH appropriately)."
        ) from e
    ext = _build_and_import_ext()
    if args.use_torch_stream:
        try:
            import torch
        except Exception as e:  # noqa: BLE001
            raise RuntimeError("--use-torch-stream requires torch to be importable.") from e
        args.stream = int(torch.cuda.current_stream().cuda_stream)

    @triton.jit
    def tensor_metric_kernel(
        device_ptr,
        device_offset_ptr,
        size: tl.uint64,
        metric_id: tl.uint64,
        metric_value_ptr,
        metric_value_size: tl.uint64,
    ):
        # Writes one record (metric_id followed by metric_value_size words of
        # payload) into a ring buffer of `size` u64 words, wrapping modulo
        # `size`, then publishes the updated offset through device_offset_ptr.
        BLOCK_SIZE: tl.constexpr = 256
        device_offset = tl.load(device_offset_ptr)
        tl.store(device_ptr + device_offset, metric_id)
        device_offset = (device_offset + 1) % size
        num_iters = tl.cdiv(metric_value_size, BLOCK_SIZE)
        offsets = tl.arange(0, BLOCK_SIZE)
        for i in tl.range(0, num_iters):
            cur_offsets = offsets + i * BLOCK_SIZE
            mask = cur_offsets < metric_value_size
            metric_value = tl.load(metric_value_ptr + cur_offsets, mask=mask)
            tl.store(device_ptr + (device_offset + cur_offsets) % size, metric_value, mask=mask)
        # NOTE(review): barrier placed after the copy loop so the ring writes
        # are ordered before the offset publish; original indentation was
        # flattened -- confirm placement against the upstream gist.
        tl.debug_barrier()
        device_offset = (device_offset + metric_value_size) % size
        tl.store(device_offset_ptr, device_offset)

    # Compile once, then launch via cuLaunchKernel in the C++ extension.
    mock_ptr = MockTensor(tl.uint64)
    compiled = tensor_metric_kernel.warmup(
        mock_ptr,
        mock_ptr,
        1024,
        1,
        mock_ptr,
        1,
        grid=(1,),
        num_warps=4,
    )
    compiled._init_handles()
    # Raw CUfunction handle, passed to the C++ launcher as an integer.
    kernel = int(compiled.function)
    if not args.no_verify:
        for ring_size_words, metric_value_size, initial_offset in _iter_cases():
            _verify_one_case(
                ext=ext,
                kernel=kernel,
                stream=args.stream,
                ring_size_words=ring_size_words,
                metric_value_size=metric_value_size,
                initial_offset=initial_offset,
                metric_id=0x4242,
            )
        print("correctness: OK")
    us_per = _bench_cpp_launch(
        ext=ext,
        kernel=kernel,
        stream=args.stream,
        ring_size_words=args.ring_size,
        metric_value_size=args.value_size,
        iters=args.iters,
        warmup=args.warmup,
    )
    # One metric_id word plus the payload, 8 bytes each.
    words = 1 + max(0, args.value_size)
    bytes_per_launch = words * 8
    print(
        f"bench: {us_per:.3f} us/launch "
        f"({bytes_per_launch} bytes written, ring_size_words={args.ring_size}, value_size={args.value_size}, iters={args.iters})"
    )
# Script entry point.
if __name__ == "__main__":
    main()
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <mutex>
#include <string>
#if __has_include(<cuda.h>)
#include <cuda.h>
#else
#error "cuda.h not found. Please set CUDA_HOME to a CUDA Toolkit install."
#endif
namespace {
// Capsule tag: guards against foreign capsules being passed to our accessors.
constexpr const char *kPinnedBufferCapsuleName =
    "tensor_metric_ext.PinnedHostBufferU64";

// A mapped, pinned host allocation of uint64 words.
//  - host_ptr: CPU-visible base address (from cuMemHostAlloc)
//  - device_ptr: device-visible alias (from cuMemHostGetDevicePointer)
//  - nelems: number of uint64 elements
struct PinnedHostBufferU64 {
  void *host_ptr = nullptr;
  CUdeviceptr device_ptr = 0;
  uint64_t nelems = 0;
};
// Initializes the CUDA driver exactly once (thread-safe via std::call_once)
// and caches the cuInit result for all subsequent calls.
// Returns true on success; on failure fills *error (if non-null) with a
// human-readable description of the failure.
bool ensure_driver_initialized(std::string *error) {
  static std::once_flag init_flag;
  static CUresult init_res = CUDA_ERROR_UNKNOWN;
  std::call_once(init_flag, []() { init_res = cuInit(0); });
  if (init_res != CUDA_SUCCESS) {
    const char *name = nullptr;
    const char *msg = nullptr;
    cuGetErrorName(init_res, &name);
    cuGetErrorString(init_res, &msg);
    if (error) {
      *error = "cuInit failed: ";
      *error += (name ? name : "UNKNOWN");
      *error += ": ";
      *error += (msg ? msg : "no message");
    }
    return false;
  }
  return true;
}
// Sets a Python RuntimeError of the form "<what>: <ERR_NAME>: <message>"
// from a CUresult and returns nullptr (the conventional CPython error value),
// so callers can `return raise_cuda(...)` directly.
PyObject *raise_cuda(const char *what, CUresult res) {
  const char *name = nullptr;
  const char *msg = nullptr;
  cuGetErrorName(res, &name);
  cuGetErrorString(res, &msg);
  return PyErr_Format(PyExc_RuntimeError, "%s: %s: %s", what,
                      (name ? name : "UNKNOWN"),
                      (msg ? msg : "no message"));
}
// Unwraps the PinnedHostBufferU64 from one of our capsules. Returns nullptr
// (with a Python error already set by PyCapsule_GetPointer) if the capsule
// does not carry kPinnedBufferCapsuleName.
PinnedHostBufferU64 *get_buf(PyObject *capsule) {
  return static_cast<PinnedHostBufferU64 *>(
      PyCapsule_GetPointer(capsule, kPinnedBufferCapsuleName));
}
// Capsule finalizer: frees the pinned host allocation and the wrapper struct
// when the Python capsule is garbage-collected.
void capsule_destructor(PyObject *capsule) {
  auto *buf = get_buf(capsule);
  if (!buf) {
    return;
  }
  if (buf->host_ptr) {
    // Best-effort cleanup; don't throw from destructor.
    (void)cuMemFreeHost(buf->host_ptr);
  }
  delete buf;
}
// alloc_pinned_host_u64(nelems: int, zero: bool=True, write_combined: bool=False) -> capsule
// Allocates `nelems` uint64 words of page-locked (pinned), device-mapped host
// memory and returns a capsule whose destructor frees the memory.
// Requires a current CUDA context on the calling thread.
PyObject *alloc_pinned_host_u64(PyObject *, PyObject *args, PyObject *kwargs) {
  unsigned long long nelems_ull = 0;
  int zero = 1;
  int write_combined = 0;
  static const char *kwnames[] = {"nelems", "zero", "write_combined", nullptr};
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "K|pp", const_cast<char **>(kwnames),
                                   &nelems_ull, &zero, &write_combined)) {
    return nullptr;
  }
  if (nelems_ull == 0) {
    return PyErr_Format(PyExc_ValueError, "nelems must be > 0");
  }
  // Reject sizes whose byte count would overflow size_t.
  if (nelems_ull > (static_cast<unsigned long long>(SIZE_MAX) / sizeof(uint64_t))) {
    return PyErr_Format(PyExc_OverflowError, "nelems too large");
  }
  std::string init_err;
  if (!ensure_driver_initialized(&init_err)) {
    return PyErr_Format(PyExc_RuntimeError, "%s", init_err.c_str());
  }
  // Fail with an actionable message when no context is current, instead of a
  // cryptic CUDA error from the allocation call below.
  CUcontext ctx = nullptr;
  CUresult res = cuCtxGetCurrent(&ctx);
  if (res != CUDA_SUCCESS) {
    return raise_cuda("cuCtxGetCurrent failed", res);
  }
  if (ctx == nullptr) {
    return PyErr_Format(
        PyExc_RuntimeError,
        "No current CUDA context on this thread. Initialize CUDA first "
        "(e.g. import torch; torch.cuda.init()).");
  }
  auto *buf = new PinnedHostBufferU64();
  buf->nelems = static_cast<uint64_t>(nelems_ull);
  const size_t nbytes = static_cast<size_t>(nelems_ull) * sizeof(uint64_t);
  // PORTABLE + DEVICEMAP: pinned memory visible to all contexts and mapped
  // into the device address space.
  unsigned int flags = CU_MEMHOSTALLOC_PORTABLE | CU_MEMHOSTALLOC_DEVICEMAP;
  if (write_combined) {
    flags |= CU_MEMHOSTALLOC_WRITECOMBINED;
  }
  res = cuMemHostAlloc(&buf->host_ptr, nbytes, flags);
  if (res != CUDA_SUCCESS) {
    delete buf;
    return raise_cuda("cuMemHostAlloc failed", res);
  }
  res = cuMemHostGetDevicePointer(&buf->device_ptr, buf->host_ptr, 0);
  if (res != CUDA_SUCCESS) {
    (void)cuMemFreeHost(buf->host_ptr);
    delete buf;
    return raise_cuda("cuMemHostGetDevicePointer failed", res);
  }
  if (zero) {
    std::memset(buf->host_ptr, 0, nbytes);
  }
  return PyCapsule_New(buf, kPinnedBufferCapsuleName, capsule_destructor);
}
// buf_host_ptr(capsule) -> int
// Exposes the CPU-visible base address of the pinned buffer as a Python int.
PyObject *buf_host_ptr(PyObject *, PyObject *args) {
  PyObject *cap = nullptr;
  if (!PyArg_ParseTuple(args, "O", &cap)) {
    return nullptr;
  }
  PinnedHostBufferU64 *b = get_buf(cap);
  if (b == nullptr) {
    return nullptr;  // error already set by get_buf
  }
  const uintptr_t addr = reinterpret_cast<uintptr_t>(b->host_ptr);
  return PyLong_FromUnsignedLongLong(static_cast<unsigned long long>(addr));
}
// buf_device_ptr(capsule) -> int
// Exposes the device-visible alias of the pinned buffer as a Python int.
PyObject *buf_device_ptr(PyObject *, PyObject *args) {
  PyObject *cap = nullptr;
  if (!PyArg_ParseTuple(args, "O", &cap)) {
    return nullptr;
  }
  PinnedHostBufferU64 *b = get_buf(cap);
  if (b == nullptr) {
    return nullptr;  // error already set by get_buf
  }
  const uintptr_t addr = static_cast<uintptr_t>(b->device_ptr);
  return PyLong_FromUnsignedLongLong(static_cast<unsigned long long>(addr));
}
// buf_nelems(capsule) -> int
// Returns the buffer length in uint64 elements.
PyObject *buf_nelems(PyObject *, PyObject *args) {
  PyObject *cap = nullptr;
  if (!PyArg_ParseTuple(args, "O", &cap)) {
    return nullptr;
  }
  const PinnedHostBufferU64 *b = get_buf(cap);
  if (b == nullptr) {
    return nullptr;  // error already set by get_buf
  }
  return PyLong_FromUnsignedLongLong(static_cast<unsigned long long>(b->nelems));
}
// fill_u64(capsule, value) -> None
// Overwrites every word of the host-side buffer with `value`.
PyObject *fill_u64(PyObject *, PyObject *args) {
  PyObject *cap = nullptr;
  unsigned long long fill = 0;
  if (!PyArg_ParseTuple(args, "OK", &cap, &fill)) {
    return nullptr;
  }
  PinnedHostBufferU64 *b = get_buf(cap);
  if (b == nullptr) {
    return nullptr;  // error already set by get_buf
  }
  uint64_t *dst = reinterpret_cast<uint64_t *>(b->host_ptr);
  uint64_t *const end = dst + b->nelems;
  const uint64_t word = static_cast<uint64_t>(fill);
  while (dst != end) {
    *dst++ = word;
  }
  Py_RETURN_NONE;
}
// iota_u64(capsule, start) -> None
// Writes start, start+1, ... into the host-side buffer (uint64 wraparound).
PyObject *iota_u64(PyObject *, PyObject *args) {
  PyObject *cap = nullptr;
  unsigned long long first = 0;
  if (!PyArg_ParseTuple(args, "OK", &cap, &first)) {
    return nullptr;
  }
  PinnedHostBufferU64 *b = get_buf(cap);
  if (b == nullptr) {
    return nullptr;  // error already set by get_buf
  }
  uint64_t *dst = reinterpret_cast<uint64_t *>(b->host_ptr);
  uint64_t next = static_cast<uint64_t>(first);
  for (uint64_t idx = 0; idx < b->nelems; ++idx) {
    dst[idx] = next;
    ++next;
  }
  Py_RETURN_NONE;
}
// read_u64(capsule, idx) -> int
// Bounds-checked read of one uint64 word from the host-side buffer.
PyObject *read_u64(PyObject *, PyObject *args) {
  PyObject *cap = nullptr;
  unsigned long long pos = 0;
  if (!PyArg_ParseTuple(args, "OK", &cap, &pos)) {
    return nullptr;
  }
  const PinnedHostBufferU64 *b = get_buf(cap);
  if (b == nullptr) {
    return nullptr;  // error already set by get_buf
  }
  if (pos >= b->nelems) {
    return PyErr_Format(PyExc_IndexError, "index out of range");
  }
  const uint64_t *src = reinterpret_cast<const uint64_t *>(b->host_ptr);
  return PyLong_FromUnsignedLongLong(static_cast<unsigned long long>(src[pos]));
}
// write_u64(capsule, idx, value) -> None
// Bounds-checked write of one uint64 word into the host-side buffer.
PyObject *write_u64(PyObject *, PyObject *args) {
  PyObject *cap = nullptr;
  unsigned long long pos = 0;
  unsigned long long word = 0;
  if (!PyArg_ParseTuple(args, "OKK", &cap, &pos, &word)) {
    return nullptr;
  }
  PinnedHostBufferU64 *b = get_buf(cap);
  if (b == nullptr) {
    return nullptr;  // error already set by get_buf
  }
  if (pos >= b->nelems) {
    return PyErr_Format(PyExc_IndexError, "index out of range");
  }
  uint64_t *dst = reinterpret_cast<uint64_t *>(b->host_ptr);
  dst[pos] = static_cast<uint64_t>(word);
  Py_RETURN_NONE;
}
// stream_synchronize(stream: int) -> None
// Blocks until all work queued on the given CUstream handle has completed
// (handle 0 is the default stream).
PyObject *stream_synchronize(PyObject *, PyObject *args) {
  unsigned long long stream_ull = 0;
  if (!PyArg_ParseTuple(args, "K", &stream_ull)) {
    return nullptr;
  }
  std::string init_err;
  if (!ensure_driver_initialized(&init_err)) {
    return PyErr_Format(PyExc_RuntimeError, "%s", init_err.c_str());
  }
  CUresult res =
      cuStreamSynchronize(reinterpret_cast<CUstream>(static_cast<uintptr_t>(stream_ull)));
  if (res != CUDA_SUCCESS) {
    return raise_cuda("cuStreamSynchronize failed", res);
  }
  Py_RETURN_NONE;
}
// event_create(disable_timing: bool=False) -> int
// Creates a CUDA event and returns its raw handle as an integer. The caller
// owns the handle and must pass it to event_destroy.
PyObject *event_create(PyObject *, PyObject *args, PyObject *kwargs) {
  int disable_timing = 0;
  static const char *kwnames[] = {"disable_timing", nullptr};
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|p", const_cast<char **>(kwnames),
                                   &disable_timing)) {
    return nullptr;
  }
  std::string init_err;
  if (!ensure_driver_initialized(&init_err)) {
    return PyErr_Format(PyExc_RuntimeError, "%s", init_err.c_str());
  }
  CUevent ev = nullptr;
  // Timing is required for event_elapsed_ms, so CU_EVENT_DEFAULT by default.
  unsigned int flags = disable_timing ? CU_EVENT_DISABLE_TIMING : CU_EVENT_DEFAULT;
  CUresult res = cuEventCreate(&ev, flags);
  if (res != CUDA_SUCCESS) {
    return raise_cuda("cuEventCreate failed", res);
  }
  return PyLong_FromUnsignedLongLong(static_cast<unsigned long long>(reinterpret_cast<uintptr_t>(ev)));
}
// event_destroy(event: int) -> None
// Destroys a CUDA event previously returned by event_create.
PyObject *event_destroy(PyObject *, PyObject *args) {
  unsigned long long ev_ull = 0;
  if (!PyArg_ParseTuple(args, "K", &ev_ull)) {
    return nullptr;
  }
  std::string init_err;
  if (!ensure_driver_initialized(&init_err)) {
    return PyErr_Format(PyExc_RuntimeError, "%s", init_err.c_str());
  }
  CUevent ev = reinterpret_cast<CUevent>(static_cast<uintptr_t>(ev_ull));
  CUresult res = cuEventDestroy(ev);
  if (res != CUDA_SUCCESS) {
    return raise_cuda("cuEventDestroy failed", res);
  }
  Py_RETURN_NONE;
}
// event_record(event: int, stream: int) -> None
// Records the event on the given stream (captures the stream's current
// position for later synchronization/timing).
PyObject *event_record(PyObject *, PyObject *args) {
  unsigned long long ev_ull = 0;
  unsigned long long stream_ull = 0;
  if (!PyArg_ParseTuple(args, "KK", &ev_ull, &stream_ull)) {
    return nullptr;
  }
  std::string init_err;
  if (!ensure_driver_initialized(&init_err)) {
    return PyErr_Format(PyExc_RuntimeError, "%s", init_err.c_str());
  }
  CUevent ev = reinterpret_cast<CUevent>(static_cast<uintptr_t>(ev_ull));
  CUstream stream = reinterpret_cast<CUstream>(static_cast<uintptr_t>(stream_ull));
  CUresult res = cuEventRecord(ev, stream);
  if (res != CUDA_SUCCESS) {
    return raise_cuda("cuEventRecord failed", res);
  }
  Py_RETURN_NONE;
}
// event_synchronize(event: int) -> None
// Blocks until the event's recorded work has completed.
PyObject *event_synchronize(PyObject *, PyObject *args) {
  unsigned long long ev_ull = 0;
  if (!PyArg_ParseTuple(args, "K", &ev_ull)) {
    return nullptr;
  }
  std::string init_err;
  if (!ensure_driver_initialized(&init_err)) {
    return PyErr_Format(PyExc_RuntimeError, "%s", init_err.c_str());
  }
  CUevent ev = reinterpret_cast<CUevent>(static_cast<uintptr_t>(ev_ull));
  CUresult res = cuEventSynchronize(ev);
  if (res != CUDA_SUCCESS) {
    return raise_cuda("cuEventSynchronize failed", res);
  }
  Py_RETURN_NONE;
}
// event_elapsed_ms(start: int, end: int) -> float
// Returns the elapsed time in milliseconds between two recorded events.
// Both events must have been created with timing enabled.
PyObject *event_elapsed_ms(PyObject *, PyObject *args) {
  unsigned long long start_ull = 0;
  unsigned long long end_ull = 0;
  if (!PyArg_ParseTuple(args, "KK", &start_ull, &end_ull)) {
    return nullptr;
  }
  std::string init_err;
  if (!ensure_driver_initialized(&init_err)) {
    return PyErr_Format(PyExc_RuntimeError, "%s", init_err.c_str());
  }
  CUevent start = reinterpret_cast<CUevent>(static_cast<uintptr_t>(start_ull));
  CUevent end = reinterpret_cast<CUevent>(static_cast<uintptr_t>(end_ull));
  float ms = 0.0f;
  CUresult res = cuEventElapsedTime(&ms, start, end);
  if (res != CUDA_SUCCESS) {
    return raise_cuda("cuEventElapsedTime failed", res);
  }
  return PyFloat_FromDouble(static_cast<double>(ms));
}
// launch_tensor_metric_kernel(kernel, stream, device_ptr, device_offset_ptr,
//                             size_words, metric_id, metric_value_ptr,
//                             metric_value_size, iters=1) -> None
// Launches the Triton-compiled kernel `iters` times back-to-back on `stream`
// via cuLaunchKernel with a 1x1x1 grid of 128 threads. All handles/pointers
// are passed in as raw integers from Python.
PyObject *launch_tensor_metric_kernel(PyObject *, PyObject *args, PyObject *kwargs) {
  unsigned long long kernel_ull = 0;
  unsigned long long stream_ull = 0;
  unsigned long long device_ptr_ull = 0;
  unsigned long long device_offset_ptr_ull = 0;
  unsigned long long size_words_ull = 0;
  unsigned long long metric_id_ull = 0;
  unsigned long long metric_value_ptr_ull = 0;
  unsigned long long metric_value_size_ull = 0;
  unsigned long long iters_ull = 1;
  static const char *kwnames[] = {"kernel", "stream", "device_ptr", "device_offset_ptr", "size_words",
                                  "metric_id", "metric_value_ptr", "metric_value_size", "iters", nullptr};
  if (!PyArg_ParseTupleAndKeywords(
          args, kwargs, "KKKKKKKK|K", const_cast<char **>(kwnames), &kernel_ull, &stream_ull,
          &device_ptr_ull, &device_offset_ptr_ull, &size_words_ull, &metric_id_ull,
          &metric_value_ptr_ull, &metric_value_size_ull, &iters_ull)) {
    return nullptr;
  }
  // iters == 0 is a no-op, not an error.
  if (iters_ull == 0) {
    Py_RETURN_NONE;
  }
  std::string init_err;
  if (!ensure_driver_initialized(&init_err)) {
    return PyErr_Format(PyExc_RuntimeError, "%s", init_err.c_str());
  }
  CUfunction kernel = reinterpret_cast<CUfunction>(static_cast<uintptr_t>(kernel_ull));
  CUstream stream = reinterpret_cast<CUstream>(static_cast<uintptr_t>(stream_ull));
  CUdeviceptr device_ptr = static_cast<CUdeviceptr>(static_cast<uintptr_t>(device_ptr_ull));
  CUdeviceptr device_offset_ptr = static_cast<CUdeviceptr>(static_cast<uintptr_t>(device_offset_ptr_ull));
  uint64_t size_words = static_cast<uint64_t>(size_words_ull);
  uint64_t metric_id = static_cast<uint64_t>(metric_id_ull);
  CUdeviceptr metric_value_ptr = static_cast<CUdeviceptr>(static_cast<uintptr_t>(metric_value_ptr_ull));
  uint64_t metric_value_size = static_cast<uint64_t>(metric_value_size_ull);
  // Two trailing null scratch pointers after the user arguments -- these
  // presumably match the ABI of the Triton-compiled kernel used here
  // (global/profile scratch); NOTE(review): confirm if Triton is upgraded.
  void *global_scratch_ptr = nullptr;
  void *profile_scratch_ptr = nullptr;
  // Parameter order must match the kernel signature exactly.
  void *kernel_params[] = {reinterpret_cast<void *>(&device_ptr),
                           reinterpret_cast<void *>(&device_offset_ptr),
                           reinterpret_cast<void *>(&size_words),
                           reinterpret_cast<void *>(&metric_id),
                           reinterpret_cast<void *>(&metric_value_ptr),
                           reinterpret_cast<void *>(&metric_value_size),
                           reinterpret_cast<void *>(&global_scratch_ptr),
                           reinterpret_cast<void *>(&profile_scratch_ptr)};
  for (unsigned long long i = 0; i < iters_ull; ++i) {
    CUresult res =
        cuLaunchKernel(kernel, 1, 1, 1, 128, 1, 1, 0, stream, kernel_params, nullptr);
    if (res != CUDA_SUCCESS) {
      return raise_cuda("cuLaunchKernel failed", res);
    }
  }
  Py_RETURN_NONE;
}
// Python-visible method table; the docstrings double as the module's usage
// reference. Must end with the null sentinel entry.
PyMethodDef kMethods[] = {
    {"alloc_pinned_host_u64", (PyCFunction)alloc_pinned_host_u64,
     METH_VARARGS | METH_KEYWORDS,
     "alloc_pinned_host_u64(nelems: int, zero: bool=True, write_combined: bool=False) -> capsule\n"
     "Allocates mapped, pinned host memory (uint64 elements) and returns a capsule."},
    {"buf_host_ptr", buf_host_ptr, METH_VARARGS,
     "buf_host_ptr(capsule) -> int\nReturns host pointer as integer."},
    {"buf_device_ptr", buf_device_ptr, METH_VARARGS,
     "buf_device_ptr(capsule) -> int\nReturns device-mapped pointer as integer."},
    {"buf_nelems", buf_nelems, METH_VARARGS,
     "buf_nelems(capsule) -> int\nReturns number of uint64 elements."},
    {"fill_u64", fill_u64, METH_VARARGS,
     "fill_u64(capsule, value: int) -> None\nFills the buffer (on host) with value."},
    {"iota_u64", iota_u64, METH_VARARGS,
     "iota_u64(capsule, start: int) -> None\nFills the buffer with start, start+1, ... (uint64)."},
    {"read_u64", read_u64, METH_VARARGS,
     "read_u64(capsule, idx: int) -> int\nReads one uint64 from host buffer."},
    {"write_u64", write_u64, METH_VARARGS,
     "write_u64(capsule, idx: int, value: int) -> None\nWrites one uint64 into host buffer."},
    {"stream_synchronize", stream_synchronize, METH_VARARGS,
     "stream_synchronize(stream: int) -> None\nSynchronizes a CUDA stream handle."},
    {"event_create", (PyCFunction)event_create, METH_VARARGS | METH_KEYWORDS,
     "event_create(disable_timing: bool=False) -> int\nCreates a CUDA event and returns its handle."},
    {"event_destroy", event_destroy, METH_VARARGS,
     "event_destroy(event: int) -> None\nDestroys a CUDA event."},
    {"event_record", event_record, METH_VARARGS,
     "event_record(event: int, stream: int) -> None\nRecords an event on a stream."},
    {"event_synchronize", event_synchronize, METH_VARARGS,
     "event_synchronize(event: int) -> None\nSynchronizes an event."},
    {"event_elapsed_ms", event_elapsed_ms, METH_VARARGS,
     "event_elapsed_ms(start: int, end: int) -> float\nReturns elapsed time between events (ms)."},
    {"launch_tensor_metric_kernel", (PyCFunction)launch_tensor_metric_kernel,
     METH_VARARGS | METH_KEYWORDS,
     "launch_tensor_metric_kernel(kernel: int, stream: int, device_ptr: int, device_offset_ptr: int,\n"
     "                            size_words: int, metric_id: int, metric_value_ptr: int,\n"
     "                            metric_value_size: int, iters: int=1) -> None\n"
     "Launches the Triton-compiled tensor_metric_kernel via cuLaunchKernel."},
    {nullptr, nullptr, 0, nullptr},
};
// Module definition. m_size of -1 means the module keeps state in globals
// and does not support multiple initialization.
PyModuleDef kModule = {
    PyModuleDef_HEAD_INIT,
    "tensor_metric_ext",
    "Small CUDA driver helpers for reproducing tensor_metric_kernel behavior.",
    -1,
    kMethods,
    nullptr,
    nullptr,
    nullptr,
    nullptr,
};
} // namespace
extern "C" PyObject *PyInit_tensor_metric_ext(void) { return PyModule_Create(&kModule); }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment