Created January 27, 2026 10:10
PyTorch 2.10 patch for https://github.com/mengqin/SageAttention
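
In outline, the patch does three things: it adds explicit CUDA and CCCL includes (<cuda/barrier>, <cuda/pipeline>, <cuda/std/atomic>, <cuda_fp16.h>, <cuda_runtime.h>) to headers and translation units that previously received them transitively; it guards <cuda_pipeline_primitives.h> behind __CUDACC__ so the plain C++ pybind files still compile under the host compiler; and it updates setup.py to pass the CCCL include directory (CUDA_HOME/include/cccl, the CUDA 13 toolkit layout) to every extension, plus /DNOMINMAX on Windows. A sketch of a more defensive include-directory lookup follows the diff.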
diff --git a/csrc/fused/fused.cu b/csrc/fused/fused.cu
index fb8b9f1..571a9e6 100644
--- a/csrc/fused/fused.cu
+++ b/csrc/fused/fused.cu
@@ -14,6 +14,12 @@
  * limitations under the License.
  */
+#include <cuda/barrier>
+#include <cuda/pipeline>
+#include <cuda/std/atomic>
+#include <cuda_fp16.h>
+#include <cuda_pipeline_primitives.h>
+#include <cuda_runtime.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
diff --git a/csrc/fused/pybind.cpp b/csrc/fused/pybind.cpp
index bffdb06..78c0c1e 100644
--- a/csrc/fused/pybind.cpp
+++ b/csrc/fused/pybind.cpp
@@ -14,6 +14,9 @@
  * limitations under the License.
  */
+// #include <cuda/barrier>
+// #include <cuda/std/atomic>
+#include <cuda_runtime.h>
 #include <torch/extension.h>
 #include <cuda_fp16.h>
 #include "fused.h"
diff --git a/csrc/qattn/attn_cuda_sm80.h b/csrc/qattn/attn_cuda_sm80.h
index 83e830d..f27efc2 100644
--- a/csrc/qattn/attn_cuda_sm80.h
+++ b/csrc/qattn/attn_cuda_sm80.h
@@ -14,6 +14,14 @@
  * limitations under the License.
  */
+#include <cuda_runtime.h>
+#include <cuda/barrier>
+#include <cuda/pipeline>
+#include <cuda/std/atomic>
+#include <cuda_fp16.h>
+#ifdef __CUDACC__
+#include <cuda_pipeline_primitives.h>
+#endif
 #include <torch/extension.h>
 torch::Tensor qk_int8_sv_f16_accum_f32_attn(torch::Tensor query,
diff --git a/csrc/qattn/attn_cuda_sm89.h b/csrc/qattn/attn_cuda_sm89.h
index f45f8d7..5cd9d6c 100644
--- a/csrc/qattn/attn_cuda_sm89.h
+++ b/csrc/qattn/attn_cuda_sm89.h
@@ -14,6 +14,14 @@
  * limitations under the License.
  */
+#include <cuda_runtime.h>
+#include <cuda/barrier>
+#include <cuda/pipeline>
+#include <cuda/std/atomic>
+#include <cuda_fp16.h>
+#ifdef __CUDACC__
+#include <cuda_pipeline_primitives.h>
+#endif
 #include <torch/extension.h>
 torch::Tensor qk_int8_sv_f8_accum_f32_attn(torch::Tensor query,
diff --git a/csrc/qattn/attn_cuda_sm90.h b/csrc/qattn/attn_cuda_sm90.h
index 822ab7f..3c24122 100644
--- a/csrc/qattn/attn_cuda_sm90.h
+++ b/csrc/qattn/attn_cuda_sm90.h
@@ -14,6 +14,14 @@
  * limitations under the License.
  */
+#include <cuda_runtime.h>
+#include <cuda/barrier>
+#include <cuda/pipeline>
+#include <cuda/std/atomic>
+#include <cuda_fp16.h>
+#ifdef __CUDACC__
+#include <cuda_pipeline_primitives.h>
+#endif
 #include <torch/extension.h>
 torch::Tensor qk_int8_sv_f8_accum_f32_attn_inst_buf(
diff --git a/csrc/qattn/attn_utils.cuh b/csrc/qattn/attn_utils.cuh
index 6f8b87c..57921d6 100644
--- a/csrc/qattn/attn_utils.cuh
+++ b/csrc/qattn/attn_utils.cuh
@@ -15,6 +15,7 @@
  */
 #pragma once
+#include <cuda_runtime.h>
 #include "../utils.cuh"
 #include <cuda_fp16.h>
 #include <cuda_pipeline_primitives.h>
diff --git a/csrc/qattn/pybind_sm80.cpp b/csrc/qattn/pybind_sm80.cpp
index 0d5e71c..98e8a23 100644
--- a/csrc/qattn/pybind_sm80.cpp
+++ b/csrc/qattn/pybind_sm80.cpp
@@ -13,7 +13,9 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
+#include <cuda_fp16.h>
+// #include <cuda_pipeline_primitives.h>
+#include <cuda_runtime.h>
 #include <pybind11/pybind11.h>
 #include <torch/extension.h>
 #include "attn_cuda_sm80.h"
diff --git a/csrc/qattn/pybind_sm89.cpp b/csrc/qattn/pybind_sm89.cpp
index 559179c..a2bcc37 100644
--- a/csrc/qattn/pybind_sm89.cpp
+++ b/csrc/qattn/pybind_sm89.cpp
@@ -13,7 +13,9 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
+#include <cuda_fp16.h>
+// #include <cuda_pipeline_primitives.h>
+#include <cuda_runtime.h>
 #include <pybind11/pybind11.h>
 #include <torch/extension.h>
 #include "attn_cuda_sm89.h"
diff --git a/csrc/qattn/pybind_sm90.cpp b/csrc/qattn/pybind_sm90.cpp
index 8900b64..d295d24 100644
--- a/csrc/qattn/pybind_sm90.cpp
+++ b/csrc/qattn/pybind_sm90.cpp
@@ -13,7 +13,9 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
+#include <cuda_fp16.h>
+// #include <cuda_pipeline_primitives.h>
+#include <cuda_runtime.h>
 #include <pybind11/pybind11.h>
 #include <torch/extension.h>
 #include "attn_cuda_sm90.h"
diff --git a/csrc/utils.cuh b/csrc/utils.cuh
index ca83d2c..bc17eca 100644
--- a/csrc/utils.cuh
+++ b/csrc/utils.cuh
@@ -15,6 +15,9 @@
  */
 #pragma once
+#include <cuda/barrier>
+#include <cuda/pipeline>
+#include <cuda/std/atomic>
 #include <torch/extension.h>
 #define CHECK_CUDA(x) \
diff --git a/setup.py b/setup.py
index 41b9a54..73295ef 100644
--- a/setup.py
+++ b/setup.py
@@ -48,8 +48,10 @@ if not SKIP_CUDA_BUILD:
     SUPPORTED_ARCHS = {"8.0", "8.6", "8.9", "9.0", "10.0", "12.0", "12.1"}
 
     # Compiler flags.
+    CCCL_INCLUDE_DIR = os.path.join(CUDA_HOME, "include", "cccl")
     if os.name == "nt":
         CXX_FLAGS = ["/Zi", "/openmp", "/std:c++17", "/Zc:__cplusplus", "/DENABLE_BF16", "/MD", "/permissive-"]
+        CXX_FLAGS += ["/DNOMINMAX"]
         NVCC_FLAGS = [
             "-O3",
             "-std=c++17",
@@ -192,6 +194,7 @@ if not SKIP_CUDA_BUILD:
                 "csrc/qattn/pybind_sm80.cpp",
                 "csrc/qattn/qk_int_sv_f16_cuda_sm80.cu",
             ],
+            include_dirs=[CCCL_INCLUDE_DIR],
             extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
         )
     )
@@ -210,6 +213,7 @@ if not SKIP_CUDA_BUILD:
                 "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu",
                 "csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu",
             ],
+            include_dirs=[CCCL_INCLUDE_DIR],
             extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
         )
     )
@@ -228,6 +232,7 @@ if not SKIP_CUDA_BUILD:
                 "csrc/qattn/pybind_sm90.cpp",
                 "csrc/qattn/qk_int_sv_f8_cuda_sm90.cu",
             ],
+            include_dirs=[CCCL_INCLUDE_DIR],
             extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS_SM90},
             extra_link_args=cuda_lib,
         )
@@ -237,6 +242,7 @@ if not SKIP_CUDA_BUILD:
         CUDAExtension(
             name="sageattention._fused",
             sources=["csrc/fused/pybind.cpp", "csrc/fused/fused.cu"],
+            include_dirs=[CCCL_INCLUDE_DIR],
             extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
         )
     )
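
The setup.py hunks hard-code CCCL_INCLUDE_DIR = os.path.join(CUDA_HOME, "include", "cccl"), which is where CUDA 13 toolkits place the CCCL headers (<cuda/barrier>, <cuda/pipeline>, <cuda/std/atomic>, plus CUB and Thrust); earlier toolkits keep them directly under include/, which nvcc already searches. Below is a minimal sketch, not part of the patch, of a lookup that degrades gracefully on pre-13 toolkits; it assumes only that torch.utils.cpp_extension.CUDA_HOME resolves, and the helper name cccl_include_dirs is hypothetical.

# Sketch only -- not from the gist. Resolves the extra include path needed
# for the CCCL headers on CUDA 13, returning nothing on older toolkits
# where the default include directory already suffices.
import os

from torch.utils.cpp_extension import CUDA_HOME

def cccl_include_dirs(cuda_home=CUDA_HOME):
    """Return extra include dirs so <cuda/pipeline> and friends resolve."""
    if cuda_home is None:
        raise RuntimeError("CUDA toolkit not found (CUDA_HOME is unset)")
    cccl = os.path.join(cuda_home, "include", "cccl")
    # CUDA 13 moved the CCCL headers under include/cccl; only add the
    # directory when it actually exists.
    return [cccl] if os.path.isdir(cccl) else []

Each CUDAExtension(...) call in setup.py could then pass include_dirs=cccl_include_dirs() instead of the fixed [CCCL_INCLUDE_DIR]. The patch keeps the fixed path, which is reasonable when the target toolkit is known to be the CUDA 13 one that PyTorch 2.10 builds against.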