Torch Memory Saver 소개
Torch Memory Saver는 PyTorch에서 GPU 메모리를 효율적으로 관리하기 위해 개발된 경량 유틸리티 라이브러리입니다. GPU 메모리 부족은 대규모 모델 학습 및 추론 시 흔히 발생하는 문제로, 불필요한 텐서가 오랫동안 메모리를 점유하면 자원 활용도가 떨어집니다. Torch Memory Saver는 특정 시점에 텐서의 물리적 메모리를 해제하고 필요할 때 다시 복구할 수 있는 기능을 제공합니다.
이 과정에서 중요한 점은 물리적 메모리를 해제하더라도 가상 주소는 그대로 유지된다는 것입니다. 따라서 CUDA Graph와 같은 구조에서도 오류 없이 재할당이 가능하며, 강화학습(RL) 환경처럼 반복적인 메모리 사용이 많은 경우에도 효과적으로 동작합니다. 기존의 Gradient Checkpointing이 계산량을 늘리는 방식이라면, Torch Memory Saver는 단순히 메모리 점유 상태를 조절하는 방식이어서 보다 직관적입니다.
특히 개발자가 직접 CUDA 메모리 관리 방식을 실험할 수 있도록 LD_PRELOAD 기반의 빠른 해킹(hacky experiment)도 지원합니다. 이를 통해 cudaMalloc 호출을 가로채어 메모리 해제를 직접 구현할 수 있으며, GPU 메모리 사용량을 눈에 띄게 줄일 수 있음을 확인했습니다.
Gradient Checkpointing과의 차이
Gradient Checkpointing은 역전파 과정에서 중간 값을 저장하지 않고 필요할 때 다시 계산하는 방식으로 메모리를 절약합니다. 반면 Torch Memory Saver는 이미 계산된 텐서를 GPU 메모리에서 해제하고, 필요 시 같은 주소로 다시 복구하는 방식입니다. 따라서 계산량이 늘지 않으며, CUDA Graph 같은 최적화된 실행 환경에서도 안전하게 사용할 수 있다는 장점이 있습니다.
DeepSpeed, Megatron-LM 같은 프레임워크가 전체 학습 파이프라인 최적화에 집중하는 반면, Torch Memory Saver는 작은 유틸리티 형태로 특정 텐서의 메모리만 제어할 수 있다는 점에서 훨씬 가볍고 단순합니다.
Torch Memory Saver의 주요 기능
기본적인 사용 예시
기본적인 Torch Memory Saver 사용법응ㄴ 다음과 같습니다:
with torch_memory_saver.region():
pauseable_tensor = torch.full((1_000_000_000,), 100, dtype=torch.uint8, device='cuda')
torch_memory_saver.pause() # 메모리 해제
torch_memory_saver.resume() # 메모리 복구
pause() 호출 시 물리 메모리는 해제되지만 가상 주소는 유지되고, resume() 호출 시 같은 주소로 다시 매핑됩니다.
태그 기반 제어
여러 텐서를 태그별로 관리할 수 있어 선택적으로 해제 및 복구가 가능합니다:
with torch_memory_saver.region(tag="type1"):
tensor1 = torch.full((5_000_000_000,), 100, dtype=torch.uint8, device='cuda')
with torch_memory_saver.region(tag="type2"):
tensor2 = torch.full((5_000_000_000,), 100, dtype=torch.uint8, device='cuda')
torch_memory_saver.pause("type1")
torch_memory_saver.resume("type2")
CUDA Graph 지원
torch_memory_saver.cuda_graph(...) API를 사용하면 CUDA Graph 실행 중에도 텐서 메모리를 해제할 수 있습니다. CUDA Graph는 가상 주소가 변하지 않는 한 정상적으로 동작하므로, Torch Memory Saver의 해제·복구 메커니즘과 잘 호환됩니다.
CPU 백업 옵션
pause() 시 데이터를 버리지 않고 CPU에 보관할 수도 있습니다.
with torch_memory_saver.region(enable_cpu_backup=True):
tensor = torch.full((5_000_000_000,), 42, dtype=torch.uint8, device='cuda')
torch_memory_saver.pause()
torch_memory_saver.resume()
assert tensor == 42
Hook 모드
preload:LD_PRELOAD를 사용하여 CUDA malloc/free API를 가로채는 방식torch: PyTorch의 custom allocator API의 동작을 지정하는 방식
Hook 모드는 다음과 같이 사용 가능합니다:
torch_memory_saver.hook_mode = "torch"
Hacky Experiment: LD_PRELOAD 기반 설계
개발자는 Torch Memory Saver의 아이디어를 검증하기 위해 LD_PRELOAD를 사용한 실험을 진행했습니다. 이 방식은 cudaMalloc 호출을 가로채어 CUDA 드라이버 API(cuMemCreate, cuMemMap)로 대체하는 것입니다.
- 메모리 해제(hack_release_occupation): cuMemUnmap과 cuMemRelease를 사용해 물리 메모리를 해제하지만 가상 주소는 유지
- 메모리 복구(hack_resume_occupation): 기존 가상 주소에 cuMemCreate+cuMemMap으로 다시 매핑
이 실험은 모든 물리 메모리를 해제하는 단순한 방식이지만, CUDA Graph는 여전히 정상 동작했으며 nvidia-smi로 확인한 메모리 사용량도 확실히 줄었습니다. 다만 실제 설계에서는 개별 할당 단위별로 어떤 메모리를 해제할지 세밀하게 추적할 수 있도록 해야 합니다.
빌드 및 실행 방법은 다음과 같습니다:
# 빌드
g++ my_preload.cc -o my_preload.so -shared -fPIC -lcuda -I/usr/local/cuda/include
# 실행
LD_PRELOAD=./my_preload.so python3 example.py
라이선스
Torch Memory Saver 프로젝트는 MIT 라이선스로 공개 및 배포되고 있습니다. 상업적 사용을 포함해 자유롭게 활용할 수 있습니다.
Torch Memory Saver GitHub 저장소
더 읽어보기
이 글은 GPT 모델로 정리한 글을 바탕으로 한 것으로, 원문의 내용 또는 의도와 다르게 정리된 내용이 있을 수 있습니다. 관심있는 내용이시라면 원문도 함께 참고해주세요! 읽으시면서 어색하거나 잘못된 내용을 발견하시면 덧글로 알려주시기를 부탁드립니다. ![]()
파이토치 한국 사용자 모임
이 정리한 이 글이 유용하셨나요? 회원으로 가입하시면 주요 글들을 이메일
로 보내드립니다! (기본은 Weekly지만 Daily로 변경도 가능합니다.)
아래
쪽에 좋아요
를 눌러주시면 새로운 소식들을 정리하고 공유하는데 힘이 됩니다~ ![]()
