JAX란 무엇인가요?

JAX란 무엇인가요?
JAX는 고성능 머신 러닝 연구를 위해 결합된 Autograd와 XLA입니다.

업데이트된 버전의 Autograd를 통해 JAX는 네이티브 Python과 NumPy 함수를 자동으로 구분할 수 있습니다. 루프, 브랜치, 재귀, 클로저를 통해 차별화할 수 있으며, 파생 함수의 파생 함수를 취할 수 있습니다. 정방향 미분뿐만 아니라 그라데이션을 통한 역방향 미분(일명 역전파)도 지원하며, 이 두 가지를 임의의 순서로 구성할 수 있습니다.

새로운 기능은 JAX가 XLA를 사용하여 GPU와 TPU에서 NumPy 프로그램을 컴파일하고 실행한다는 것입니다. 컴파일은 기본적으로 내부에서 이루어지며 라이브러리 호출은 적시에 컴파일되고 실행됩니다. 그러나 JAX를 사용하면 단일 함수 API인 jit를 사용하여 자체 Python 함수를 XLA에 최적화된 커널로 적시에 컴파일할 수도 있습니다. 컴파일과 자동 차별화를 임의로 구성할 수 있으므로 Python을 벗어나지 않고도 정교한 알고리즘을 표현하고 최대 성능을 얻을 수 있습니다. pmap을 사용하면 여러 개의 GPU나 TPU 코어를 한 번에 프로그래밍하고 전체를 통해 차별화할 수도 있습니다.

조금 더 자세히 살펴보면 JAX가 실제로 컴포저블 함수 변환을 위한 확장 가능한 시스템이라는 것을 알 수 있습니다. grad와 jit는 모두 이러한 변환의 인스턴스입니다. 그 외에도 자동 벡터화를 위한 vmap과 여러 가속기의 단일 프로그램 다중 데이터(SPMD) 병렬 프로그래밍을 위한 pmap이 있으며, 앞으로 더 많은 기능이 추가될 예정입니다.

https://github.com/google/jax