FlashKDA: MoonshotAI가 공개한, Kimi Delta Attention을 위한 고성능 CUDA 커널 라이브러리

FlashKDA 소개

대규모 언어 모델의 추론 효율을 높이기 위한 선형 어텐션(Linear Attention) 연구가 빠르게 발전하고 있습니다. Transformer의 표준 소프트맥스 어텐션은 시퀀스 길이 T 에 대해 O(T^2) 의 연산 복잡도를 가지기 때문에, 긴 컨텍스트 처리 시 메모리와 연산 비용이 크게 증가합니다. 이를 해결하기 위해 Mamba, RWKV, RetNet 등 선형 복잡도 아키텍처들이 등장했으며, 최근에는 이런 모델들의 핵심 연산을 CUDA 수준에서 하드웨어에 최적화하는 작업이 중요해졌습니다. 특히 Triton으로 작성된 선형 어텐션 커널은 구현이 유연하지만, CUTLASS 기반의 고도로 최적화된 CUDA 커널에 비해 성능 차이가 존재합니다.

FlashKDA(Flash Kimi Delta Attention)는 MoonshotAI(Kimi 개발사)가 공개한, Kimi 모델에서 사용하는 Delta Attention(KDA, Kimi Delta Attention) 연산을 위한 고성능 CUDA 커널 라이브러리입니다. NVIDIA CUTLASS를 기반으로 구축되었으며, Hopper(H100) 및 그 이상의 GPU 아키텍처(SM90+)를 대상으로 최적화되었습니다. Flash Attention이 소프트맥스 어텐션을 하드웨어 수준에서 최적화한 것처럼, FlashKDA는 KDA 연산을 동일한 방식으로 가속화합니다. 2026년 4월 22일에는 FlashKDA v1의 설계 결정 사항을 상세히 설명하는 딥다이브 블로그 포스트도 공개되었습니다.

FlashKDAflash-linear-attention 라이브러리와 통합되어, chunk_kda 함수 호출 시 자동으로 FlashKDA 커널로 디스패치됩니다. 기존 Triton 구현 경로를 대체하여 H100 등 최신 GPU에서 더 높은 성능을 제공합니다. FLA_FLASH_KDA=0 환경 변수를 설정하면 언제든지 Triton 경로로 되돌릴 수 있어, 디버깅이나 하위 호환성이 필요한 경우에 유용합니다.

Flash Linear Attention과 비교

FlashKDA는 기존의 널리 사용되는 Flash Linear Attention과 비교했을 때, 특히 청크 크기(Chunk Size)의 선택에서 뚜렷한 차별점을 보입니다. Flash Linear Attention이 일반적으로 64의 청크 크기를 채택하는 반면, FlashKDA의 초기 버전(v1)은 청크 크기를 16으로 과감하게 줄여서 설계되었습니다. 이러한 결정은 16x16 행렬의 역행렬을 계산하는 것이 64x64 행렬을 처리하는 것보다 연산 비용이 극적으로 저렴하기 때문입니다. 작아진 행렬 덕분에 추가적인 복잡한 분해(Decomposition) 과정 없이 노이만 급수 전개(Neumann-series expansion)만으로 직접 역행렬을 구할 수 있어 연산 파이프라인이 크게 간소화되는 장점이 있습니다.

수치적 안정성 측면에서도 작은 청크 크기는 큰 이점을 제공합니다. FlashKDA는 게이트(Gate)의 하한값을 -5로 설정하고 16의 청크 크기를 유지함으로써, 지수 연산 결과의 범위를 BFloat16(bf16) 데이터 타입이 표현할 수 있는 한계 내로 안정적으로 맞췄습니다. 결과적으로 큰 청크 크기를 사용할 때 불가피하게 요구되었던 복잡한 청크 내 리스케일링(Intra-chunk rescaling) 트릭을 완전히 제거할 수 있었습니다. 더불어 이 모든 연산 구조는 최신 하드웨어뿐만 아니라 SM80 아키텍처의 MMA 명령어들에도 완벽하게 매핑되므로 높은 이식성(Portability)을 자랑합니다.

FlashKDA의 KDA 연산 구조

KDA(Kimi Delta Attention)는 선형 어텐션의 변형으로, 일반적인 선형 어텐션 수식에 게이팅(Gating), 베타(Beta) 스케일링, L2 정규화 등 추가 요소가 결합된 형태입니다. FlashKDA의 커널 API인 flash_kda.fwd는 다음 파라미터들을 받습니다:

파라미터 자료형 형태 설명
q bf16 [B, T, H, K] 쿼리(Query)
k bf16 [B, T, H, K] 키(Key)
v bf16 [B, T, H, V] 밸류(Value)
g bf16 [B, T, H, K] 활성화 전 게이트
beta bf16 [B, T, H] 베타 로짓(내부적으로 시그모이드 적용)
A_log fp32 [H] 로그 게이트 파라미터
dt_bias fp32 [H, K] 게이트 바이어스
lower_bound float scalar 게이트 하한(-5.0 ~ 0 범위)
initial_state bf16/fp32 [B, H, V, K] (선택) 초기 순환 상태
final_state bf16/fp32 [B, H, V, K] (선택, 출력) 최종 순환 상태

이 파라미터 구조에서 볼 수 있듯이 KDA는 Q, K, V 외에도 게이트(g), 베타(beta), 로그 게이트(A_log), 게이트 바이어스(dt_bias) 등 여러 추가 파라미터로 어텐션 패턴을 세밀하게 제어합니다. 내부적으로는 QK L2 정규화(use_qk_l2norm_in_kernel), 베타 시그모이드(use_beta_sigmoid_in_kernel), 안전 게이트(safe_gate) 등의 연산이 커널 내에 융합(Fused)되어 있어 메모리 왕복 횟수를 최소화합니다.

FlashKDA 주요 아키텍처 및 특징

커널 퓨전(Kernel Fusion) 전략의 분리

초기 프로토타입은 단일 퓨전 커널(Fused Kernel) 구조를 채택했으나, 최종 릴리스에서는 병렬성(Parallelism)의 축을 기준으로 연산을 K1과 K2라는 두 개의 독립된 커널로 분할했습니다.

  • K1 커널 (토큰 병렬성 기반): 게이트 활성화, L2 정규화, 감쇠(Decay) 적용 및 행렬 역연산을 수행합니다.
  • K2 커널 (헤드 병렬성 기반): 청크 단위의 델타 룰 순환(Recurrence)과 결과값 투영(Output projection)을 전담합니다.

단일 커널 환경에서는 K1의 높은 토큰 병렬성이 K2의 순환 과정 병목에 막혀 다수의 스트리밍 멀티프로세서(SM)가 유휴 상태로 방치되는 문제가 있었습니다. 파이프라인을 이처럼 두 단계로 분리함으로써 각 단계를 독립적으로 튜닝할 수 있게 되었고, 15% 이상의 엔드투엔드(End-to-End) 연산 속도 향상을 이끌어냈습니다.

수치 정밀도(Numerical Precision) 최적화

메모리 절약과 연산 효율을 극대화하기 위해 온칩(On-chip) 순환 상태(Recurrent state)를 bf16 형식으로 저장합니다. 이를 통해 공유 메모리 사용 공간을 절반으로 줄이고, 매 상태 업데이트마다 발생하는 fp32에서 bf16으로의 불필요한 형변환 오버헤드를 완전히 제거했습니다. 누적(Accumulator) 자체는 fp32 FMA 명령어로 연산하므로 정확도 손실은 발생하지 않습니다. 또한 16x16 행렬의 역연산은 bf16 대신 fp16으로 수행되는데, 역행렬 요소들의 값이 [-1, 1] 구간에 머물기 때문에 좁은 동적 범위를 지닌 fp16으로도 충분하며, 오히려 노이만 급수 전개에 여유 공간(Headroom)을 주어 정확도를 더욱 향상시킵니다.

저수준(Low-level) 명령어 및 메모리 활용

NVIDIA GPU의 연산 처리량을 한계까지 끌어올리기 위한 다양한 세부 최적화 기법들이 적용되었습니다.

  • 수학 연산 가속: 시그모이드(Sigmoid) 함수는 PTX의 tanh.approx.f32 명령어로 구현되었으며, 지수 연산은 밑(base)을 2로 변환한 후 빠른 ex2.approx.ftz.f32 명령어를 활용합니다.
  • 점유율 향상: K1 커널에서는 수명 주기가 겹치지 않는 공유 메모리 변수들을 공용체(Union)로 묶어 재사용하고, __launch_bounds__(256, 8) 지시어를 통해 레지스터를 미세하게 디스크로 밀어내는(Spilling) 대신 SM당 스레드 블록 점유율을 크게 높였습니다.
  • 명령어 기반 전치: K2 커널에서는 MOVM_T 명령어를 이용해 레지스터 파일 내부에서 직접 피연산자를 전치(Transpose) 처리함으로써 느린 공유 메모리로의 왕복 접근을 완전히 제거했습니다.

FlashKDA 설치 및 사용법

시스템 요구사항

  • NVIDIA SM90 이상 GPU (H100, H20 등)
  • CUDA 12.9 이상
  • PyTorch 2.4 이상

설치

git clone https://github.com/MoonshotAI/FlashKDA.git flash-kda
cd flash-kda
git submodule update --init --recursive
pip install -v .

flash-linear-attention 백엔드로 사용

pip install -U flash-linear-attention  # >= 0.5.0 필요
import torch
from fla.ops.kda import chunk_kda

# FlashKDA가 자동으로 디스패치됨
with torch.inference_mode():
    out, final_state = chunk_kda(
        q=q, k=k, v=v, g=g, beta=beta,
        scale=scale,
        initial_state=h0,
        output_final_state=True,
        use_gate_in_kernel=True,
        use_qk_l2norm_in_kernel=True,
        use_beta_sigmoid_in_kernel=True,
        safe_gate=True,
        A_log=A_log, dt_bias=dt_bias,
        lower_bound=lower_bound,
        transpose_state_layout=True,
        cu_seqlens=cu_seqlens,
    )

디스패치 디버깅

import logging
logging.basicConfig(level=logging.INFO)
# FlashKDA 적중 시: [FLA Backend] kda.chunk_kda -> flashkda
# 실패 시: ... rejected: <reason>

Triton 경로로 전환

export FLA_FLASH_KDA=0  # FlashKDA 비활성화, Triton 경로 사용

라이선스

FlashKDA 프로젝트의 라이선스는 MIT 라이선스로 제공됩니다.

:scroll: FlashKDA 딥다이브 블로그

:github: FlashKDA 프로젝트 GitHub 저장소

더 읽어보기




이 글은 GPT 모델로 정리한 글을 바탕으로 한 것으로, 원문의 내용 또는 의도와 다르게 정리된 내용이 있을 수 있습니다. 관심있는 내용이시라면 원문도 함께 참고해주세요! 읽으시면서 어색하거나 잘못된 내용을 발견하시면 덧글로 알려주시기를 부탁드립니다. :hugs:

:pytorch:파이토치 한국 사용자 모임:south_korea:이 정리한 이 글이 유용하셨나요? 회원으로 가입하시면 주요 글들을 이메일:love_letter:로 보내드립니다! (기본은 Weekly지만 Daily로 변경도 가능합니다.)

:wrapped_gift: 아래:down_right_arrow:쪽에 좋아요:+1:를 눌러주시면 새로운 소식들을 정리하고 공유하는데 힘이 됩니다~ :star_struck: