Tokamax: JAX와 Pallas 기반의 커스텀 가속 커널 라이브러리

Tokamax 소개

Tokamax는 NVIDIA GPU와 Google TPU에서 동작하는 커스텀 가속 커널(custom accelerator kernels)을 제공하는 Python 라이브러리입니다. 이 프로젝트는 최신 JAXPallas를 기반으로 만들어졌으며, 딥러닝 모델에서 중요한 연산을 최적화된 커널로 교체해 성능을 크게 끌어올릴 수 있도록 설계되었습니다. 단순히 제공된 커널만 사용하는 것이 아니라, 사용자가 직접 자신만의 커널을 작성하고 자동 튜닝(autotuning) 기능을 통해 최적화할 수 있다는 점이 핵심적인 특징입니다.

현재 Tokamax는 개발 초기 단계에 있으며, API 변경 가능성이 크고 일부 기능은 완전하지 않습니다. 하지만 제공되는 커널은 최신 연구 성과를 반영하고 있어, 연구자와 엔지니어가 GPU 및 TPU 환경에서 고성능 실험을 진행할 때 유용합니다. 예를 들어, Transformer 모델의 핵심 연산인 tokamax.dot_product_attention(FlashAttention 기반), tokamax.gated_linear_unit(Gated linear units (SwiGLU 등)), tokamax.layer_norm(Layer NormalizationRMSNorm) 같은 고성능 커널을 지원합니다.

Tokamax의 또 다른 장점은 GPU와 TPU 모두에서 동작하는 ragged_dot 연산을 지원한다는 점입니다. 이는 Mixture of Experts(MoE) 아키텍처에서 중요한 연산으로, 대규모 모델의 효율적 학습과 추론을 가능하게 합니다. 따라서 Tokamax는 최신 LLM과 딥러닝 연구에 적합한 라이브러리라 할 수 있습니다.

Tokamax는 단순히 커널을 제공하는 라이브러리가 아니라, 여러 구현체를 선택하거나 자동으로 최적화된 커널을 고르는 기능을 제공합니다. 예를 들어 implementation=None 옵션을 주면 Tokamax가 커널의 입력 크기와 실행 환경에 따라 최적의 구현체를 자동 선택합니다. 반면, implementation="mosaic"처럼 특정 구현체를 강제할 수도 있는데, 이는 최신 GPU 아키텍처에서 최적의 성능을 기대할 수 있으나, 일부 환경에서는 실행되지 않을 수 있습니다.

이는 단일 구현체만 제공하는 기존 커널 라이브러리와 달리, Tokamax가 JAX 생태계에서 다양한 하드웨어 특성을 반영하는 유연성을 제공한다는 점에서 차별화됩니다. 또한 자동 튜닝을 통해 동일한 코드라도 환경에 따라 최적의 성능을 발휘하도록 조정할 수 있다는 점에서, 기존 JAX/XLA 단일 구현 기반 커널보다 실무 활용도가 높습니다.

Tokamax 설치 및 사용 예시

설치

Tokamax는 PyPI와 GitHub 두 가지 방식으로 설치할 수 있습니다. 안정적인 버전을 원한다면 PyPI를, 최신 기능을 실험하고 싶다면 GitHub의 소스 코드로부터 설치하는 것을 권장합니다:

# PyPI의 안정적인 버전
pip install -U tokamax

# GitHub의 최신 기능 버전
pip install git+https://github.com/openxla/tokamax.git

기본 사용 예시

Tokamax는 JAX 코드와 자연스럽게 통합됩니다. 아래는 H100 GPU에서 Tokamax를 사용하는 예시입니다.

import jax
import jax.numpy as jnp
import tokamax

def loss(x, scale):
    x = tokamax.layer_norm(x, scale=scale, offset=None, implementation="triton")
    x = tokamax.dot_product_attention(x, x, x, implementation="xla_chunked")
    return jnp.sum(x)

f_grad = jax.jit(jax.grad(loss))

여기서 implementation=None으로 설정하면 Tokamax가 가장 적절한 커널을 선택합니다. 이는 환경에 따라 forward pass와 backward pass에서 다른 커널을 쓸 수도 있으며, 필요시 XLA 기본 구현체로 fallback 됩니다.

자동 튜닝(Autotuning)

Tokamax는 커널을 자동으로 최적화하는 기능을 제공합니다. 이 과정은 실행 시간이 다소 걸릴 수 있지만, 이후에는 직렬화(Serialize)하여 재사용할 수 있습니다:

autotune_result = tokamax.autotune(f_grad, x, scale)

with autotune_result:
    out = f_grad(x, scale)

이를 통해 동일한 환경에서 반복 실행 시 일관된 성능과 수치를 확보할 수 있습니다.

커널 직렬화

Tokamax는 커널을 StableHLO 포맷으로 직렬화할 수 있습니다. 다만 JAX의 기본 export와는 달리, Tokamax의 커널은 디바이스 독립성을 보장하지 않으며 특정 디바이스에 종속됩니다. 대신 6개월간의 backward compatibility를 보장합니다.

벤치마킹 도구

JAX Python 오버헤드는 실제 커널 실행 시간보다 크게 나타날 수 있습니다. Tokamax는 CUDA events, CUPTI 프로파일러 등 다양한 방법을 지원하여 실제 가속기 실행 시간만 측정할 수 있습니다:

from tokamax import benchmarking

f_std, args = benchmarking.standardize_function(f, kwargs={'x': x, 'scale': scale})
run = benchmarking.compile_benchmark(f_std, args)
bench = run(args, method='cuda_events')

이를 통해 모델 최적화 과정에서 더 정확한 성능 측정을 할 수 있습니다.

라이선스

Tokamax 프로젝트는 Apache License 2.0으로 배포되고 있으며, 상업적 사용을 포함한 자유로운 활용이 가능합니다. 다만 보증이 없으며 사용 시 책임은 사용자에게 있습니다.

:github: Tokamax GitHub 저장소




이 글은 GPT 모델로 정리한 글을 바탕으로 한 것으로, 원문의 내용 또는 의도와 다르게 정리된 내용이 있을 수 있습니다. 관심있는 내용이시라면 원문도 함께 참고해주세요! 읽으시면서 어색하거나 잘못된 내용을 발견하시면 덧글로 알려주시기를 부탁드립니다. :hugs:

:pytorch:파이토치 한국 사용자 모임:south_korea:이 정리한 이 글이 유용하셨나요? 회원으로 가입하시면 주요 글들을 이메일:love_letter:로 보내드립니다! (기본은 Weekly지만 Daily로 변경도 가능합니다.)

:wrapped_gift: 아래:down_right_arrow:쪽에 좋아요:+1:를 눌러주시면 새로운 소식들을 정리하고 공유하는데 힘이 됩니다~ :star_struck: