Flax.linen은 JAX 기반의 신경망 라이브러리지만, 애석하게도 jax.lax에 있는 필수 하위기능을 사용할 수 없습니다(?). 자체적인 graph compilation 로직때문에 이런 것 같긴 한데, 상당히 불편합니다. 가령, jax/flax에서는 random에 의사난수 키가 강제되기 때문에, split()하는 과정이 있어야 합니다. random을 사용하는 rnn 같은 경우에 키가 계속 변경되기 때문에 코드를 짜는 입장에서는 모듈안에 split()을 불러서 나누는 게 편하지만, Flax.linen 모듈과 scan()(for-loop의 함수형 프로그래밍 버전 - for loop의 컴파일 속도가 jax 기반 라이브러리들에서는 극악이기 때문에 이걸 무조건 써야합니다)은 이걸 허락하지 않습니다. 이 것 덕분에 고생을 좀 세게 하고 Equinox로 갈아탔습니다. Equinox는 jax에서 허용되는 모든 동작들이 어느 부분에서든 자유롭게 허락됩니다. 특히, Equinox의 모듈 자체가 그냥 jax의 pytree 데이터구조여서, flax.linen과 비교했을 때 안되는 게 거의 없습니다.
JAX multi-backend에 관해
한줄 요약하자면, JAX multi-backend는 PyTorch 밑 레이어로 굴리긴 어려운 구조입니다. JAX의 multi-backend는 XLA(XLA: 머신러닝을 위한 컴파일러 최적화 | TensorFlow)를 기반으로 돌아갑니다. XLA는 TF2때 처음 소개되었는데, 다이나믹 그래프와 궁합이 잘 안 맞아서(모든 동작을 XLA 기반 동작으로 옮기기 어렵기 때문입니다 - 가령 RNN 알고리즘들은 단순 그래프 레벨이 아닌 cuda 스케줄링 레벨에서 최적화되어 있습니다) JAX를 개발한 것이기에 Torch를 XLA와 같이 쓰기엔 전통적인 사용자 입장에서는 당황스러울 수 있을 것 같습니다.
JAX는 아직 부족한 게 많습니다. 대규모의 인프라로 확장가능하다는 글의 요지에 대해는 동의하는 바이지만, 사실 dataloader와 같은 필수적이면서도 소소한 기능들에 대한 지원이 매우 많이 배제되어있습니다(multi-backend를 위한 필수적 결정이라고도 생각합니다). 저 또한 미래 컴퓨팅의 형태가 극단적인 병렬컴퓨팅 구현(확장성)이라고 생각하지만, 대부분 집이나 소규모 연구실에서 하는 딥러닝으로는 nvidia gpu 8장 이상 이용하는 것도 쉽지 않은 일이기에 파이토치와 비교시 오히려 파이토치가 비교우위에 있을 것이라 지금은 생각합니다. 다만, 클라우드를 주력으로 활용한다면 JAX는 해당 클라우드의 자체적인 가속기등을 활용해 소모되는 경비를 절감할 수 있게 한다는 점에서 효과적일 것이라 생각합니다.