cuLA:用 CUDA 重写线性注意力
什么是 cuLA
cuLA(CUDA Linear Attention)是一套面向线性注意力变体的高性能 CUDA 内核库,专为 GLA、KDA、GDN、Lightning Attention 等新一代注意力机制提供高效实现。
其核心价值在于,将这些原本复杂且工程实现门槛较高的算法,转化为可直接调用的高性能算子,使开发者无需深入底层优化,也能在大模型中高效使用线性注意力能力,显著提升长上下文场景下的推理效率与吞吐表现。
cuLA 基于 CuTe DSL 与 CUTLASS C++ 进行手工优化,针对 NVIDIA 新一代 GPU 架构进行了深度适配。同时,cuLA 在接口层与 flash-linear-attention(FLA)保持一致,开发者只需修改一行 import 即可完成替换,实现低成本接入。当前 cuLA 以独立库形式迭代,未来将通过 FLA 的 kernel dispatch 机制实现统一调度,进一步降低使用门槛并提升工程可维护性。
cuLA 开源地址:https://github.com/inclusionAI/cuLA
为什么做 cuLA
线性注意力(Linear Attention)正在成为长序列建模的核心范式,通过将复杂度从 降到 ,使百万级 token 的建模成为可能。GLA 、KDA (Kimi Delta Attention) 、GDN 、Lightning Attention 等变体进一步引入了门控、delta 更新和分块分解,在表达力和效率之间取得了更好的平衡。
然而,这些算法的高性能实现依然是瓶颈。现有的 Triton 内核 虽然开发效率高,但在寄存器分配、共享内存流水线、warp 特化等方面距离硬件极限仍有差距——尤其是在 Hopper 和 Blackwell 这两代引入了 TMA 、wgmma/UMMA 等新硬件特性的 GPU 上。
cuLA (CUDA Linear Attention) 正是为此而生。我们用 CuTe DSL 和 CUTLASS C++ 手写线性注意力内核,目标是在最新 GPU 上榨取每一分性能。
CUDA 线性注意力算子这件事,总有人会做。我们想的是——为什么不能是我们?说实话,团队里大多数人在启动 cuLA 之前几乎没有 CUDA 开发经验。但这并不妨碍我们试一试。从读第一行 CUTLASS 源码,到跑通第一个比 Triton 更快的内核,这个过程充满了乐趣。事实证明,”试一试”本身就是最好的起点。我们希望把这种心态传递出去:如果你也对 GPU 编程感到好奇,哪怕从零开始,cuLA 都欢迎你加入。
更多细节请参考 cuLA GitHub README 。
性能数据
所有 benchmark 均以 flash-linear-attention (FLA) v0.4.2 的 Triton 实现为基线。每组 configs 覆盖了不同 B (batch size)、S (sequence length)、H (head count)、D (head dimension) 的多种组合,以评估不同负载下的加速效果。
| 内核 | 场景 | 平均加速比 | 峰值加速比 |
|---|---|---|---|
| CP 兼容 KDA Forward (Blackwell SM10X) | fixed (10 configs) | 1.45x | 1.57x |
| CP 兼容 KDA Forward (Blackwell SM10X) | varlen (18 configs) | 1.32x | 1.35x |
| KDA Fused Forward (Hopper SM90) | fixed & varlen (28 configs) | 1.52x | 2.45x |
| Lightning Attention Prefill (Blackwell SM10X) | varlen 126 configs | 1.45x | 2.19x |
这些加速主要来自硬件特性的利用——TMA 异步拷贝、WGMMA/UMMA 指令、多级流水线、warp specialization。
完整 benchmark 数据请参考:Blackwell GB300 | Blackwell GB200 | Hopper H200
Status & Roadmap
General
- Integrate into flash-linear-attention via FLA’s kernel dispatch mechanism
- Polynomial approximation to mitigate the exponential bottleneck, as in Flash-Attention-4
- Larger chunk size and 2-CTA on SM10X for improved throughput
- Continuous optimization via agentic methods such as AVO
- Support for more algorithms
- Small B/H/S optimizations
- Support BF16 Beta Input.
Train
-
Modular KDA Forward (SM10X, compatible with Kimi CP )
- kda chunk intra
- chunk gated delta h
- recompute wu
- chunk fwd o
-
Modular GDN Forward / Backward Kernels (compatible with Kimi CP )
-
Backward pass optimizations
-
Kernel-level compute-communication overlapping for CP linear attention kernels (via nvshmem )
Inference
- Lightning prefill kernel (SM10X)
- Lightning decode kernel (SM90 & SM10X)
- Fused KDA prefill kernel (SM90)
- Fused KDA prefill kernel (SM10X)
- MTP support
- More aggressive fusion of small neighboring kernels like cumsum for inference scenarios
技术选型:为什么是 CuTe DSL + CUTLASS C++
我们选择 NVIDIA 的 CuTe DSL 作为主要开发语言,辅以 CUTLASS C++ 用于需要更精细控制的场景。这个选择基于几个考量:
- 硬件特性直达:无论是 CuTe DSL 还是 CUTLASS C++,都直接暴露 TMA、wgmma/UMMA、cluster launch 等新一代 GPU 特性,不存在 Triton 编译器的抽象屏障
- 编译速度碾压:这是我们特别看重 CuTe DSL 的一点——相比 CUTLASS C++ 动辄几分钟的编译时间,CuTe DSL 的编译速度快一个数量级以上。这意味着”改一行代码、跑一次 benchmark”的迭代周期从分钟级缩短到秒级,特别适合通过 AVO 等 agentic 方法进行自动化的内核迭代优化
由于 CuTe DSL 目前仍在快速发展中,部分 API 尚未完全覆盖所有场景,因此项目中部分内核为了方便仍然沿用了 CUTLASS C++ 实现。随着 CuTe DSL 的成熟,我们会逐步将更多内核迁移到 DSL 上。
我们需要社区的力量
坦率地说,CUDA 内核调优的工作量比 Triton 大一个数量级。cuLA 目前仍处于早期阶段,有大量优化空间和待实现的功能。我们真诚地邀请社区一起参与:
欢迎你的贡献
无论你的背景如何,都有适合参与的方向:
- CUDA / GPU 专家:内核性能调优、新架构适配、流水线设计
- 算法研究者:新的线性注意力变体实现、数值稳定性优化
- 系统工程师:分布式通信重叠、推理框架集成 (vLLM / SGLang)
- 测试与文档:完善测试覆盖、性能回归检测、使用文档
- 初学者:我们欢迎对 CUDA 零基础但学习能力强的同学,这是一个绝佳的 GPU 编程实战机会
项目完全开源,所有讨论和决策在 GitHub 上公开进行。
线性注意力的算法创新正在加速涌现,但如果没有高效的工程实现,再好的算法也只能停留在论文里。每一份社区贡献——无论是优化一个内核、适配一种新算法,还是修一个 bug——都在帮助线性注意力的研究者更快地验证想法、推进实验,让这条通往 AGI 的路走得更快一些。
cuLA 仍处于早期阶段,内核覆盖和性能优化都还有很大空间。我们深知,仅凭一个团队的力量远远不够——开源社区的参与才是这个项目真正的加速器。希望 cuLA 能成为一个开放的协作平台,让更多人一起完善线性注意力的基础设施,为 AGI 领域的研究者和工程师添砖加瓦。
快速上手
# 克隆并初始化
git clone https://github.com/InclusionAI/cuLA.git
git submodule update --init --recursive
# 安装依赖
pip install torch==2.9.1 --index-url https://download.pytorch.org/whl/cu129
pip install -e third_party/flash-linear-attention
pip install -e . --no-build-isolation要求:Python 3.12+, CUDA Toolkit 12.9+, PyTorch 2.9.1+,支持 Hopper (SM90) 和 Blackwell (SM10X) GPU。
cuLA 设计为 flash-linear-attention (FLA) 的 drop-in replacement,只需改一行 import:
# 之前(FLA Triton)
from fla.ops.kda import chunk_kda
# 之后(cuLA CUDA)
from cula.kda import chunk_kda其余代码完全不变。cuLA 的终极目标是通过 FLA 的 kernel dispatch 机制自动选择最优实现,让用户完全无感知地享受加速。
致谢
cuLA 站在巨人的肩膀上——flash-linear-attention 、CUTLASS 、CuTe DSL 、FlashInfer 、Flash-Attention 、FlashMLA 。感谢 FLA-org 和 NVIDIA 的卓越工作。
联系我们
如果你对实习或全职机会感兴趣,也欢迎联系
如果觉得有用,请给我们一个 Star 。