메두사: 여러 디코딩 헤더를 사용한 대규모 언어 모델 추론 가속화 프레임워크 (Medusa: Simple Framework for Accelerating LLM Generation with Multiple Decoding Heads)

PyTorchKR:fire::kr: :speech_balloon:

  • 최근 LLM 추론을 빠르게 하려는 다양한 시도들 중 하나를 발견하여 공유드립니다. :hugs:

  • 이 글은 GPT 모델로 자동 요약한 설명으로, 잘못된 내용이 있을 수 있으니 원문을 참고해주세요! :smile:
  • 읽으시면서 어색하거나 잘못된 내용을 발견하시면 덧글로 알려주시기를 부탁드립니다! :bowing_man:

메두사: 여러 디코딩 헤더를 사용한 대규모 언어 모델 추론 가속화 프레임워크 (Medusa: Simple Framework for Accelerating LLM Generation with Multiple Decoding Heads)

소개

메두사 데모: 여러 디코딩 헤더를 사용한 대규모 언어 모델 추론 가속화 프레임워크 (Medusa: Simple Framework for Accelerating LLM Generation with Multiple Decoding Heads)

대규모 언어 모델(LLM)의 비용을 통제하면서 텍스트 생성 속도를 높이는 것은 쉽지 않은 문제입니다. 이러한 텍스트 생성은 메모리 바운드가 주요한 문제로, 지연 시간의 대부분은 메모리 트랜잭션에서 발생하며 연산 그 자체에서 발생하지는 않습니다.

메두사는 한 번에 다음 번 토큰뿐만 아니라 여러번 이후의 토큰 후보들을 생성한 다음 이를 검증하는 방식으로 모델을 추가 학습합니다. 지연 시간을 줄일 수 있지만 시스템의 복잡성과 샘플링의 비효율성과 같은 문제들을 해결해야 합니다.

메두사 프레임워크

구조

기존 대규모 언어 모델(LLM)의 구조를 변경하지 않고 메두사(Medusa) 프레임워크를 적용할 수 있습니다. 메두사 프레임워크는 기존 모델은 그대로 두고, 새로운 요소들을 추가하기만 합니다. 예를 들어, 원래 모델에 '메두사 헤드'라고 불리는 추가 디코딩 헤드를 추가하여 한 번에 여러 토큰을 생성하는 방식입니다. 메두사 프레임워크의 구조는 아래와 같습니다:

메두사 헤드(Medusa Heads)

  • 이는 원래의 언어 모델 구조에 추가되는 새로운 디코딩 헤드입니다.
  • 각 메두사 헤드는 다음에 올 토큰들을 동시에 예측하는 기능을 합니다.
  • 헤드들은 원래 모델과 함께 훈련되며, 원래 모델은 훈련 중 고정됩니다​​.

트리-기반 어텐션(Tree-Based Attention)

  • 메두사 헤드가 생성한 여러 후보 토큰을 병렬로 처리하기 위해 사용됩니다.
  • 각 헤드에서 나온 상위 예측들의 카테시안 곱을 사용하여 후보 세트를 만듭니다.
  • 트리 구조의 주의 메커니즘을 사용하여, 토큰의 역사적 맥락을 유지하면서 동시에 여러 후보를 처리합니다​​.

전형적인 수용(Typical Acceptance)

  • 후보 중에서 가장 긴 타당한 접두사를 선택하기 위한 방법입니다.
  • 이는 기존 모델의 예측 확률에 기반하여 설정된 임계값을 사용합니다.
  • 샘플링 온도를 조정하여 모델의 창의성을 제어할 수 있으며, 이는 전통적인 중요도 샘플링보다 효율적입니다​​.

기존 LLM과의 통합

1. 기존 모델 유지

메두사 프레임워크를 적용할 때, 기존 언어 모델의 구조는 변경되지 않습니다. 기존 모델은 통합 과정에서 그대로 유지되며, 새로운 요소들이 추가되기만 합니다​​.

2. 메두사 헤드 추가

메두사 프레임워크의 핵심 요소인 메두사 헤드는 기존 모델에 추가되어, 여러 개의 미래 토큰을 동시에 예측합니다. 이러한 헤드들은 기존 모델과 함께 통합되어, 모델의 성능을 향상시킵니다​​.

3. 학습 과정

메두사 헤드의 학습은 기존 모델을 고정한 상태에서 진행됩니다. 즉, 기존 모델의 파라미터는 학습 과정에서 변경되지 않으며, 오직 새로운 헤드들만이 학습됩니다​​.

구현 및 결과

LLaMA 모델의 변형인 Vicuna 모델(7B, 13B, 33B)을 사용하여 메두사 프레임워크를 테스트하였습니다. ShareGPT 데이터셋을 사용하여 학습하였으며, 단일 GPU에서 몇 시간 가량의 학습 시간이 소요되었습니다. 이 때 8비트 양자화를 사용하여 메모리 요구 사항을 줄이고 효율성을 높였습니다. 이 때 약 2배 가량의 속도 향상을 보였습니다.


메두사 프레임워크의 일반적인 이득은 아래와 같습니다:

1. 속도 향상

메두사 프레임워크는 대규모 언어 모델의 텍스트 생성 속도를 약 2배 가까이 향상시킵니다. 이는 메두사 헤드와 트리 기반 주의 메커니즘을 통해 병렬 처리를 가능하게 함으로써 달성됩니다​​​​.

2. 효율적인 학습

메두사 헤드의 학습은 기존 모델을 고정한 상태에서 이루어지므로, 학습 과정이 매우 파라미터 효율적입니다. 이는 별도의 대규모 모델 학습에 비해 상대적으로 적은 자원을 요구합니다​​.

3. 단순화된 시스템 디자인

기존 모델에 메두사 헤드만 추가하므로, 시스템의 복잡성이 크게 증가하지 않습니다. 이는 분산 환경에서의 적용을 용이하게 합니다​​.

더 읽어보기

메두사 프로젝트 홈페이지

GitHub 저장소

영감을 받은 논문: 추측적 디코딩(speculative decoding)

영감을 받은 논문: 여러 디코딩 헤드(multiple decoding-head)