Needle: Gemini 3.1을 증류해 만든 26M 파라미터 함수 호출 전용 온디바이스 모델 (feat. Cactus)

Needle 소개

Needle은 Cactus Compute가 공개한, 함수 호출(Function Calling) 한 가지 작업에 특화된 26M 파라미터 규모의 초소형 언어 모델입니다. Gemini 3.1 Flash Lite를 교사 모델로 삼아 단순화된 어텐션 구조로 증류(Distillation)했으며, 사용자 질의(Query)에 대해 후보 도구(Tools) 중 어느 함수를 호출하고 어떤 인자를 넣어야 하는지를 JSON으로 출력하는 데 초점을 맞추고 있습니다. 모델 크기가 26M에 불과하기 때문에 일반 PC와 Mac 환경에서도 추론은 물론 사용자 데이터로의 파인튜닝까지 로컬에서 진행할 수 있고, 운영 환경에서는 Cactus 런타임 위에서 프리필(Prefill) 6,000 토큰/초, 디코드(Decode) 1,200 토큰/초 수준으로 동작합니다.

Needle은 단순히 모델 가중치만 공개된 프로젝트가 아니라, 26M이라는 극단적으로 작은 규모로 함수 호출 능력을 어떻게 끌어올렸는지에 대한 설계 결정을 함께 공개한 실험 보고서에 가깝습니다. 저자들은 핵심 아이디어를 "Simple Attention Network(SAN)" 로 정리했으며, 이 구조에서는 전통적인 Transformer 블록에서 파라미터의 약 2/3를 차지하는 피드포워드(Feed-Forward Network, FFN) 레이어를 완전히 제거하고, 대신 인코더-디코더 구조와 크로스 어텐션(Cross Attention)에 자원을 집중합니다. 함수 호출이 본질적으로 "질의 → 도구 목록 → 인자 추출 → JSON 조립"이라는 정렬·복사 중심 작업이라는 관찰에서 출발한 결정입니다.

가중치와 데이터 생성 코드는 모두 공개되어 있으며, Hugging Face의 Cactus-Compute/needle 리포지토리에서 다운로드할 수 있습니다. 라이선스는 MIT이며, 단일 샷(Single-shot) 함수 호출 벤치마크에서 FunctionGemma-270M, Qwen-0.6B, Granite-350M, LFM2.5-350M 같은 더 큰 모델보다 우수한 결과를 보였다고 보고됩니다. 다만 저자들도 강조하듯 26M이라는 규모는 대화형 시나리오에서는 여전히 한계가 있고, 특정 도구 세트에 맞춰 파인튜닝하는 것을 전제로 한 모델입니다. 동봉된 웹 UI(needle playground)에서 자신의 도구 정의로 데이터셋을 합성하고 파인튜닝까지 진행하도록 설계된 것도 이 때문입니다.

Needle의 Simple Attention Network 구조

Needle의 모델은 임베딩 차원 d=512, 어텐션 헤드 $8$개(GQA의 KV 헤드는 $4$개), BPE 토크나이저 어휘 크기 $8{,}192$를 사용합니다. 구조는 인코더 $12$층, 디코더 $8$층의 인코더-디코더이며, 디코더에서는 마스킹된 셀프 어텐션과 인코더 출력에 대한 크로스 어텐션을 모두 사용합니다. README에 포함된 아키텍처 다이어그램은 이 구조를 다음과 같이 보여줍니다.

d=512, 8H/4KV, BPE=8192
                                  ┌──────────────┐
                                  │  Tool Call   │
                                  └──────┬───────┘
                                        ┌┴──────────┐
                                        │  Softmax  │
                                        └─────┬─────┘
                                        ┌─────┴─────┐
                                        │ Linear (T)│  ← tied
                                        └─────┬─────┘
                                        ┌─────┴─────┐
                                        │ ZCRMSNorm │
                                        └─────┬─────┘
                                     ┌────────┴────────┐
                                     │ Decoder x 8     │
                                     │┌───────────────┐│
                                     ││ ZCRMSNorm     ││
                                     ││ Masked Self   ││
                                     ││ Attn + RoPE   ││
                                     ││ Gated Residual││
                                     │├───────────────┤│
  ┌──────────────┐                   ││ ZCRMSNorm     ││
  │ Encoder x 12 │──────────────────────▶Cross Attn   ││
  │              │                   ││ Gated Residual││
  │ ┌──────────┐ │                   │└───────────────┘│
  │ │ZCRMSNorm │ │                   └────────┬────────┘
  │ │Self Attn │ │                      ┌─────┴─────┐
  │ │ GQA+RoPE │ │                      │ Embedding │  ← shared
  │ │Gated Res │ │                      └─────┬─────┘
  │ │          │ │                    ┌───────┴───────-┐
  │ │ (no FFN) │ │                    │[EOS]<tool_call>│
  │ └──────────┘ │                    │ + answer       │
  │              │                    └───────────────-┘
  └──────┬───────┘
         │
    ┌────┴──────┐
    │ Embedding │
    └────┬──────┘
         │
    ┌────┴──────┐
    │   Text    │
    │  query    │
    └───────────┘

Needle의 설계에서 가장 눈에 띄는 결정은 다음 네 가지입니다.

FFN 제거: 표준 Transformer 블록에서 약 2/3 파라미터를 차지하는 FFN을 모두 빼고, 어텐션과 크로스 어텐션 비중을 늘립니다. 함수 호출은 위치별 비선형 특징 변환보다 정렬·복사 중심 작업이기 때문에, 작은 모델에서는 FFN 파라미터의 효용이 떨어진다는 관찰에 기반합니다.

Gated Residual: 표준 잔차(Residual) x = x + \text{Attn}(\text{Norm}(x)) 대신, 학습 가능한 게이트를 사용해 x = x + \sigma(g) \cdot \text{Attn}(\text{Norm}(x)) 형태로 잔차의 강도를 층마다 학습합니다. 초기화 시 $g = 0$이므로 \sigma(0) = 0.5, 즉 반강도 잔차에서 시작합니다.

ZCRMSNorm: 표준 RMSNorm x \cdot \gamma / \text{RMS}(x) 에서 \gamma 초기값을 0 으로 두는 변형으로, 초기에는 정규화가 거의 항등 함수(identity)에 가깝게 동작합니다. 게이트 잔차와 짝지어 학습 초기의 강한 학습된 편향을 줄이는 역할을 합니다.

INT4 QAT를 정규화로 사용: 학습 100 스텝마다 가중치를 그룹별 INT4 양자화(symmetric, group size 32)로 가짜 양자화(Fake Quantization)한 뒤 순방향 계산에 사용하고, 역전파는 Straight-Through Estimator(STE)로 흘립니다. 양자화 노이즈가 일종의 가중치 노이즈로 작용해 과적합을 완화하면서, 추론 시 사용할 양자화와 동일한 조건으로 학습하기 때문에 사후 양자화(Post-Training Quantization) 손실이 적습니다.

이 외에도 인코더 출력을 평균 풀링하여 CLIP 스타일의 대조 학습(Contrastive Loss)을 함께 적용해, 도구 후보가 많을 때 top-k 도구 검색이 가능하도록 한 점, Q/K/V/O 프로젝션에는 Muon, 나머지에는 AdamW의 듀얼 옵티마이저를 적용해 어텐션 전용 깊은 구조의 학습 안정성을 확보한 점 등 디테일이 함께 공개되어 있습니다. 사전학습은 TPU v6e 16 대로 200B 토큰을 약 27시간, 사후학습은 단일 샷 함수 호출 데이터셋 2B 토큰을 약 45분 동안 학습하는 비교적 가벼운 일정입니다.

Needle의 데이터 형식과 핵심 의사코드

Needle은 JSONL 한 줄당 query, tools, answers 세 필드를 받아 학습합니다. 도구 스키마와 정답 호출은 모두 JSON 문자열 형태입니다.

{"query": "What's the weather in Paris?", "tools": "[{\"name\":\"get_weather\",\"description\":\"Get current weather for a city.\",\"parameters\":{\"location\":{\"type\":\"string\",\"description\":\"City name.\",\"required\":true}}}]", "answers": "[{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}]"}

PyTorch 사용자에게 익숙한 형태로 Needle의 한 디코더 블록을 표현하면 대체로 다음 의사코드와 같습니다. 실제 학습 코드는 JAX 기반이지만, 구조 자체는 PyTorch에서도 그대로 재현할 수 있습니다.

# Simple Attention Network: FFN을 제거하고 게이트 잔차로 깊이를 유지
class GatedResidual(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gate = nn.Parameter(torch.zeros(1))  # sigmoid(0)=0.5에서 시작

    def forward(self, x, sublayer_out):
        return x + torch.sigmoid(self.gate) * sublayer_out


class NeedleDecoderBlock(nn.Module):
    def __init__(self, d=512, n_heads=8, n_kv=4):
        super().__init__()
        self.norm_self = ZCRMSNorm(d)       # gamma=0으로 초기화된 RMSNorm
        self.self_attn = MaskedSelfAttention(d, n_heads, n_kv, use_rope=True)
        self.res_self = GatedResidual(d)

        self.norm_cross = ZCRMSNorm(d)
        self.cross_attn = CrossAttention(d, n_heads, n_kv)
        self.res_cross = GatedResidual(d)
        # 의도적으로 FFN 없음 — 어텐션이 정렬·복사를 전담

    def forward(self, x, enc_out, enc_mask=None):
        x = self.res_self(x, self.self_attn(self.norm_self(x)))
        x = self.res_cross(x, self.cross_attn(self.norm_cross(x), enc_out, enc_mask))
        return x

토큰 단위 손실 가중치는 함수 호출 작업의 오류 분포에 맞춰 조정되어 있습니다. 모델은 학습 초반에 이미 약 99%의 JSON 파싱 성공률에 도달하기 때문에, 베이스 JSON 구조에는 $1.0$의 가중치를, 인자 값에는 4.0, 인자 키에는 1.5, 도구 이름에는 $2.0$의 가중치를 부여하여 실제 오류가 자주 발생하는 부분에 학습 신호를 집중시킵니다. 보조 손실로는 로짓 안정성을 위한 z-loss와 0.1배 가중의 CLIP 대조 손실이 함께 사용됩니다.

Needle 설치 및 사용법

저장소를 클론한 뒤 자체 셋업 스크립트로 환경을 구성하고, 웹 UI 플레이그라운드(needle playground)에서 자신의 도구로 추론과 파인튜닝을 시도해 볼 수 있습니다.

git clone https://github.com/cactus-compute/needle.git
cd needle && source ./setup
needle playground

Python에서는 다음과 같이 모델을 로드하고 함수 호출을 생성합니다.

from needle import SimpleAttentionNetwork, load_checkpoint, generate, get_tokenizer

# 사전학습된 체크포인트 로드 (가중치는 자동 다운로드)
params, config = load_checkpoint("checkpoints/needle.pkl")
model = SimpleAttentionNetwork(config)
tokenizer = get_tokenizer()

# 도구 정의와 함께 질의를 전달
result = generate(
    model, params, tokenizer,
    query="What's the weather in San Francisco?",
    tools='[{"name":"get_weather","description":"Get current weather for a city.","parameters":{"location":{"type":"string","description":"City name.","required":true}}}]',
    stream=False,
)
print(result)
# [{"name":"get_weather","arguments":{"location":"San Francisco"}}]

자신만의 도구 데이터셋에 대해 파인튜닝을 진행하려면 JSONL 파일을 만들고 다음 명령으로 학습할 수 있습니다. 저자들은 도구 하나당 최소 120개(학습 100 / 검증 10 / 테스트 10)의 예시를 권장합니다. 데이터가 너무 적으면 학습 지표는 완벽해 보여도 실제 일반화 성능이 떨어지기 쉽습니다.

# 플레이그라운드(데이터 합성 → 학습 → 평가까지 자동)
needle playground

# CLI 직접 파인튜닝 (가중치는 자동 다운로드)
needle finetune data.jsonl

# 파인튜닝된 체크포인트로 단일 추론
needle run \
  --checkpoint checkpoints/needle_finetuned_*_best.pkl \
  --query "What's the weather?" \
  --tools '[{"name":"get_weather","description":"Get current weather for a city.","parameters":{"location":{"type":"string","description":"City name.","required":true}}}]'

이 외에도 needle pretrain, needle train, needle eval, needle generate-data, needle tpu 같은 서브커맨드가 제공되어, 데이터 합성부터 사전학습/평가까지 동일한 CLI에서 다룰 수 있습니다.

Needle 라이선스

Cactus Needle 프로젝트는 MIT 라이선스로 공개되어 있어 개인 및 상업적 목적으로 자유롭게 사용·수정·배포할 수 있습니다. Hugging Face의 모델 가중치도 함께 공개되어 있어, 직접 추론 백엔드에 통합하거나 자신만의 도구 세트에 맞춰 재학습하는 용도로 활용할 수 있습니다.

:house: Cactus Compute 공식 홈페이지

:books: Needle의 Simple Attention Network 설계 문서

:github: Needle 프로젝트 GitHub 저장소

:hugs: Needle 모델 다운로드

더 읽어보기




이 글은 GPT 모델로 정리한 글을 바탕으로 한 것으로, 원문의 내용 또는 의도와 다르게 정리된 내용이 있을 수 있습니다. 관심있는 내용이시라면 원문도 함께 참고해주세요! 읽으시면서 어색하거나 잘못된 내용을 발견하시면 덧글로 알려주시기를 부탁드립니다. :hugs:

:pytorch:파이토치 한국 사용자 모임:south_korea:이 정리한 이 글이 유용하셨나요? 회원으로 가입하시면 주요 글들을 이메일:love_letter:로 보내드립니다! (기본은 Weekly지만 Daily로 변경도 가능합니다.)

:wrapped_gift: 아래:down_right_arrow:쪽에 좋아요:+1:를 눌러주시면 새로운 소식들을 정리하고 공유하는데 힘이 됩니다~ :star_struck:

1개의 좋아요