@kohya-ss
Created January 27, 2026 10:10
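This patch touches SageAttention's csrc/ tree and setup.py: it adds explicit CUDA and CCCL includes (<cuda/barrier>, <cuda/pipeline>, <cuda/std/atomic>, <cuda_fp16.h>, <cuda_pipeline_primitives.h>, <cuda_runtime.h>) to the fused and qattn sources, guards the device-only <cuda_pipeline_primitives.h> behind __CUDACC__ in the host-visible headers, points the build at the toolkit's bundled CCCL headers via CUDA_HOME/include/cccl, and adds /DNOMINMAX to the MSVC flags. This is presumably aimed at toolkits that ship CCCL under a separate include/cccl directory. To try it, save the raw diff and apply it from the repository root with git apply.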
diff --git a/csrc/fused/fused.cu b/csrc/fused/fused.cu
index fb8b9f1..571a9e6 100644
--- a/csrc/fused/fused.cu
+++ b/csrc/fused/fused.cu
@@ -14,6 +14,12 @@
* limitations under the License.
*/
+#include <cuda/barrier>
+#include <cuda/pipeline>
+#include <cuda/std/atomic>
+#include <cuda_fp16.h>
+#include <cuda_pipeline_primitives.h>
+#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
diff --git a/csrc/fused/pybind.cpp b/csrc/fused/pybind.cpp
index bffdb06..78c0c1e 100644
--- a/csrc/fused/pybind.cpp
+++ b/csrc/fused/pybind.cpp
@@ -14,6 +14,9 @@
* limitations under the License.
*/
+// #include <cuda/barrier>
+// #include <cuda/std/atomic>
+#include <cuda_runtime.h>
#include <torch/extension.h>
#include <cuda_fp16.h>
#include "fused.h"
diff --git a/csrc/qattn/attn_cuda_sm80.h b/csrc/qattn/attn_cuda_sm80.h
index 83e830d..f27efc2 100644
--- a/csrc/qattn/attn_cuda_sm80.h
+++ b/csrc/qattn/attn_cuda_sm80.h
@@ -14,6 +14,14 @@
* limitations under the License.
*/
+#include <cuda_runtime.h>
+#include <cuda/barrier>
+#include <cuda/pipeline>
+#include <cuda/std/atomic>
+#include <cuda_fp16.h>
+#ifdef __CUDACC__
+#include <cuda_pipeline_primitives.h>
+#endif
#include <torch/extension.h>
torch::Tensor qk_int8_sv_f16_accum_f32_attn(torch::Tensor query,
diff --git a/csrc/qattn/attn_cuda_sm89.h b/csrc/qattn/attn_cuda_sm89.h
index f45f8d7..5cd9d6c 100644
--- a/csrc/qattn/attn_cuda_sm89.h
+++ b/csrc/qattn/attn_cuda_sm89.h
@@ -14,6 +14,14 @@
* limitations under the License.
*/
+#include <cuda_runtime.h>
+#include <cuda/barrier>
+#include <cuda/pipeline>
+#include <cuda/std/atomic>
+#include <cuda_fp16.h>
+#ifdef __CUDACC__
+#include <cuda_pipeline_primitives.h>
+#endif
#include <torch/extension.h>
torch::Tensor qk_int8_sv_f8_accum_f32_attn(torch::Tensor query,
diff --git a/csrc/qattn/attn_cuda_sm90.h b/csrc/qattn/attn_cuda_sm90.h
index 822ab7f..3c24122 100644
--- a/csrc/qattn/attn_cuda_sm90.h
+++ b/csrc/qattn/attn_cuda_sm90.h
@@ -14,6 +14,14 @@
* limitations under the License.
*/
+#include <cuda_runtime.h>
+#include <cuda/barrier>
+#include <cuda/pipeline>
+#include <cuda/std/atomic>
+#include <cuda_fp16.h>
+#ifdef __CUDACC__
+#include <cuda_pipeline_primitives.h>
+#endif
#include <torch/extension.h>
torch::Tensor qk_int8_sv_f8_accum_f32_attn_inst_buf(
diff --git a/csrc/qattn/attn_utils.cuh b/csrc/qattn/attn_utils.cuh
index 6f8b87c..57921d6 100644
--- a/csrc/qattn/attn_utils.cuh
+++ b/csrc/qattn/attn_utils.cuh
@@ -15,6 +15,7 @@
*/
#pragma once
+#include <cuda_runtime.h>
#include "../utils.cuh"
#include <cuda_fp16.h>
#include <cuda_pipeline_primitives.h>
diff --git a/csrc/qattn/pybind_sm80.cpp b/csrc/qattn/pybind_sm80.cpp
index 0d5e71c..98e8a23 100644
--- a/csrc/qattn/pybind_sm80.cpp
+++ b/csrc/qattn/pybind_sm80.cpp
@@ -13,7 +13,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
+#include <cuda_fp16.h>
+// #include <cuda_pipeline_primitives.h>
+#include <cuda_runtime.h>
#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include "attn_cuda_sm80.h"
diff --git a/csrc/qattn/pybind_sm89.cpp b/csrc/qattn/pybind_sm89.cpp
index 559179c..a2bcc37 100644
--- a/csrc/qattn/pybind_sm89.cpp
+++ b/csrc/qattn/pybind_sm89.cpp
@@ -13,7 +13,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
+#include <cuda_fp16.h>
+// #include <cuda_pipeline_primitives.h>
+#include <cuda_runtime.h>
#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include "attn_cuda_sm89.h"
diff --git a/csrc/qattn/pybind_sm90.cpp b/csrc/qattn/pybind_sm90.cpp
index 8900b64..d295d24 100644
--- a/csrc/qattn/pybind_sm90.cpp
+++ b/csrc/qattn/pybind_sm90.cpp
@@ -13,7 +13,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
+#include <cuda_fp16.h>
+// #include <cuda_pipeline_primitives.h>
+#include <cuda_runtime.h>
#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include "attn_cuda_sm90.h"
diff --git a/csrc/utils.cuh b/csrc/utils.cuh
index ca83d2c..bc17eca 100644
--- a/csrc/utils.cuh
+++ b/csrc/utils.cuh
@@ -15,6 +15,9 @@
*/
#pragma once
+#include <cuda/barrier>
+#include <cuda/pipeline>
+#include <cuda/std/atomic>
#include <torch/extension.h>
#define CHECK_CUDA(x) \
diff --git a/setup.py b/setup.py
index 41b9a54..73295ef 100644
--- a/setup.py
+++ b/setup.py
@@ -48,8 +48,10 @@ if not SKIP_CUDA_BUILD:
SUPPORTED_ARCHS = {"8.0", "8.6", "8.9", "9.0", "10.0", "12.0", "12.1"}
# Compiler flags.
+ CCCL_INCLUDE_DIR = os.path.join(CUDA_HOME, "include", "cccl")
if os.name == "nt":
CXX_FLAGS = ["/Zi", "/openmp", "/std:c++17", "/Zc:__cplusplus", "/DENABLE_BF16", "/MD", "/permissive-"]
+ CXX_FLAGS += ["/DNOMINMAX"]
NVCC_FLAGS = [
"-O3",
"-std=c++17",
@@ -192,6 +194,7 @@ if not SKIP_CUDA_BUILD:
"csrc/qattn/pybind_sm80.cpp",
"csrc/qattn/qk_int_sv_f16_cuda_sm80.cu",
],
+ include_dirs=[CCCL_INCLUDE_DIR],
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
)
)
@@ -210,6 +213,7 @@ if not SKIP_CUDA_BUILD:
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu",
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu",
],
+ include_dirs=[CCCL_INCLUDE_DIR],
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
)
)
@@ -228,6 +232,7 @@ if not SKIP_CUDA_BUILD:
"csrc/qattn/pybind_sm90.cpp",
"csrc/qattn/qk_int_sv_f8_cuda_sm90.cu",
],
+ include_dirs=[CCCL_INCLUDE_DIR],
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS_SM90},
extra_link_args=cuda_lib,
)
@@ -237,6 +242,7 @@ if not SKIP_CUDA_BUILD:
CUDAExtension(
name="sageattention._fused",
sources=["csrc/fused/pybind.cpp", "csrc/fused/fused.cu"],
+ include_dirs=[CCCL_INCLUDE_DIR],
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
)
)
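The setup.py hunk hard-codes CCCL_INCLUDE_DIR = os.path.join(CUDA_HOME, "include", "cccl"), which only exists on toolkits with the newer layout. Below is a minimal, more defensive sketch of the same idea, assuming torch's CUDA_HOME detection is available; the find_cccl_include helper and its fallback are hypothetical and not part of the patch:

import os

from torch.utils.cpp_extension import CUDA_HOME


def find_cccl_include(cuda_home=CUDA_HOME):
    """Return the toolkit's CCCL include directory if it ships one
    (newer toolkits place libcu++/CUB/Thrust under include/cccl),
    otherwise fall back to the classic include directory."""
    if cuda_home is None:
        raise RuntimeError("CUDA toolkit not found (CUDA_HOME is unset)")
    cccl = os.path.join(cuda_home, "include", "cccl")
    return cccl if os.path.isdir(cccl) else os.path.join(cuda_home, "include")


if __name__ == "__main__":
    # Prints the directory you would pass as include_dirs=[...] in setup.py.
    print(find_cccl_include())

Using a check like this instead of the unconditional path keeps the patch working on older toolkits where include/cccl does not exist.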