MaxText, Jax와 Python으로 구현한 오픈소스 LLM 학습 프레임워크 (feat. Google)

MaxText, Google이 Jax와 Python으로 구현하여 공개한 오픈소스 LLM 학습 프레임워크

소개

Google이 이전에 공개한 대규모 언어 모델(LLM) 학습을 위한 오픈소스 프레임워크인 MaxText를 뒤늦게 발견하여 줏어왔습니다. Jax와 Python으로 구현되어 있으며, Google Cloud TPU와 GPU에서 최적화된 학습과 추론이 가능하다고 하는데요. 과연 어떤 특징들이 있는지, 다른 프레임워크들과는 어떻게 다른지 함께 살펴보시죠!

MaxText는 Google에서 개발한 오픈소스 LLM 프레임워크로, 순수 Python과 Jax로 작성되었습니다. Google Cloud TPU와 GPU에서 고성능 학습과 추론을 지원하는 것이 주요 특징인데요. Jax와 XLA 컴파일러 덕분에 간단하면서도 최적화가 잘 된 코드를 작성할 수 있다고 합니다.

MaxText는 연구와 프로덕션 모두에서 활용할 수 있는 야심찬 LLM 프로젝트의 출발점이 되는 것을 목표로 하고 있습니다. 사용자들은 바로 MaxText를 실험해 보고, 필요에 따라 포크해서 수정하는 것이 권장됩니다.

Runtime Performance Results: TPU v5p

Google은 MaxText로 int8에서의 고성능 학습을 시연하고, 51K개 칩 규모로 학습을 확장한 바 있습니다.

유사 프로젝트와의 비교

MaxText는 PyTorch로 작성된 Nvidia GPU용 독립형 GPT 구현체인 MinGPT/NanoGPT에서 많은 영감을 받았습니다. MaxText는 더 많은 모델을 지원하고 대규모로 확장 가능하다는 점에서 더 복잡합니다. 최종적으로는 NanoGPT에서 언급된 17% MFU보다 3배 이상 높은 MFU를 달성하고, 효율적인 autoregressive 디코딩을 위한 key-value cache를 구현했습니다.

Nvidia Megatron-LM은 Nvidia GPU를 대상으로 잘 튜닝된 LLM 구현체인데, MaxText와 유사한 면이 있습니다. 두 구현체 모두 비슷한 MFU를 달성하지만, 프로그래밍 전략에서 차이를 보입니다. MaxText는 순수 Python으로 작성되어 XLA 컴파일러에 크게 의존하는 반면, Megatron-LM은 Python과 CUDA를 혼합해서 최적화된 CUDA 커널을 활용합니다.

한편 Jax 기반의 LLM 구현체인 Pax와도 비교해볼 수 있습니다. Pax는 강력한 configuration을 제공해 개발자들이 config 파라미터만 수정하면 모델을 변경할 수 있도록 하는데 중점을 둡니다. 반면 MaxText는 다양한 LLM의 간단하고 구체적인 구현체로, 사용자가 직접 소스 코드를 수정하면서 확장해나가는 것을 장려합니다.

주요 특징

  • TPU와 GPU 모두 지원 (GPU는 프리뷰)
  • 학습과 추론 모두 지원
  • 지원 모델: Llama2, Mistral, Gemma
  • 높은 MFU(Model Flops Utilization) 달성 - 높은 MFU는 하드웨어 자원을 효율적으로 활용하고 있다는 뜻
  • 소규모부터 대규모 클러스터까지 손쉬운 확장
  • 간단하고 최적화가 잘된 코드 (Jax와 XLA 컴파일러 덕분)

사용 방법

MaxText를 사용하기 위해서는 먼저 First Run 가이드를 따라 환경 설정을 해주어야 합니다. 그 후에는 간단한 설정 파일과 실행 명령어로 모델 학습과 추론을 진행할 수 있습니다.

모델 학습

모델 학습을 위해서는 우선 설정 파일(예: configs/base.yml)을 준비합니다.

# MaxText/configs/base.yml
include_in_name:
- global_parameter_scale
- per_device_batch_size
- learning_rate

model:
  name: GPTLikeModel
  global_parameter_scale: 16  # 16B model
  per_device_batch_size: 4
  max_sequence_length: 2048

train:
  dataset_path: gs://my-dataset-bucket
  steps: 10000
  learning_rate: 1e-3

그 다음 아래 명령어로 학습을 실행합니다.

export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true"
python3 MaxText/train.py MaxText/configs/base.yml run_name=example \
base_output_directory=gs://my-output-bucket

위 명령어는 16B 모델을 10000 스텝 동안 학습하는 예시입니다. 출력 결과는 gs://my-output-bucket에 저장됩니다.

모델 추론

학습된 체크포인트를 사용해 추론을 실행하는 예시는 다음과 같습니다.

export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true"
python3 MaxText/decode.py MaxText/configs/base.yml decode_from=pretrained \
pretrained_checkpoint_format=gs://my-output-bucket/example/checkpoint-{checkpoint}.pkl

위 명령어는 gs://my-output-bucket/example/ 경로에 저장된 체크포인트를 사용해 추론을 실행합니다.

추론 설정은 configs/base.yml에서 아래와 같이 지정할 수 있습니다.

# MaxText/configs/base.yml 
...(생략)...

decode:
  repetition_penalty: 1.0
  temperature: 0.7
  top_k: 40
  top_p: 0.9
  max_decode_steps: 128

위와 같이 repetition penalty, temperature, top-k, top-p 샘플링 등 다양한 디코딩 옵션을 조정할 수 있습니다.

이처럼 MaxText는 간단한 설정과 실행 명령어로 대규모 언어 모델의 학습과 추론을 쉽게 진행할 수 있도록 도와줍니다. 더 자세한 내용은 MaxText Github 저장소를 참고해 주세요.

라이선스

MaxText는 Apache 2.0 라이선스로 공개 및 배포되고 있습니다.

더 읽어보기

MaxText GitHub 저장소

Google의 Jax와 TPU에 대한 정보




이 글은 GPT 모델로 정리한 글을 바탕으로 한 것으로, 원문의 내용 또는 의도와 다르게 정리된 내용이 있을 수 있습니다. 관심있는 내용이시라면 원문도 함께 참고해주세요! 읽으시면서 어색하거나 잘못된 내용을 발견하시면 덧글로 알려주시기를 부탁드립니다. :hugs:

:pytorch:파이토치 한국 사용자 모임:kr:이 정리한 이 글이 유용하셨나요? 회원으로 가입하시면 주요 글들을 이메일:love_letter:로 보내드립니다! (기본은 Weekly지만 Daily로 변경도 가능합니다.)

:gift: 아래:arrow_lower_right:쪽에 좋아요:heart:를 눌러주시면 새로운 소식들을 정리하고 공유하는데 힘이 됩니다~ :star_struck: