jax

Auto-generated from google/jax by Mutable.ai Auto Wiki

GitHub Repository
Developer: google
Written in: Python
Stars: 26k
Watchers: 322
Created: 10/25/2018
Last updated: 01/05/2024
License: Apache License 2.0
Homepage: jax.readthedocs.io
Repository: google/jax

Auto Wiki Revision
Software Version: 0.0.4 (Basic)
Generated from: Commit b8098b
Generated at: 01/05/2024

JAX is a Python library that provides primitives for automatic differentiation, JIT compilation, vectorization, parallelization, and other transformations to accelerate and simplify machine learning research.

The key functionality of JAX centers on transformations that compile and optimize Python and NumPy code to run efficiently on accelerators like GPUs and TPUs. JAX transforms pure Python+NumPy functions into efficient low-level code by tracing computations and applying just-in-time XLA compilation. Rather than explicitly coding algorithms for accelerators, users write numerical code in Python and apply transformations like jax.jit and jax.grad to get optimized, device-specific implementations.

Some of the most important transformations and capabilities provided by JAX include (each is illustrated in the sketch following this list):

  • Automatic differentiation with jax.grad() to easily take derivatives of functions. JAX uses operator overloading and traces code to build a representation of the computation graph which is used to compute gradients via reverse-mode automatic differentiation. This allows machine learning models to be trained by taking gradients of loss functions.

  • Just-in-time compilation with jax.jit() to compile and optimize Python functions into XLA (Accelerated Linear Algebra) computations that run efficiently on hardware like GPUs and TPUs. JAX traces and compiles pure Python+NumPy functions into optimized low-level code using XLA.

  • Vectorization with jax.vmap() to automatically vectorize functions over array axes for batching/parallelism. For example, a function mapping from vectors to scalars can be lifted to operate over matrices in a data-parallel fashion without any code changes.

  • Multi-device parallelization with jax.pmap() to automatically parallelize functions across multiple devices like GPUs and TPU cores. This implements data parallelism by sharding inputs across devices, running the function in parallel, and gathering outputs.
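
To make these concrete, here is a minimal sketch (the loss function and shapes are illustrative, not taken from the JAX codebase) showing grad, jit, and vmap composing on a toy example; pmap follows the same pattern but needs multiple devices:

    import jax
    import jax.numpy as jnp

    # A toy squared-error loss for a linear model, in plain Python+NumPy style.
    def loss(w, x, y):
        return jnp.mean((x @ w - y) ** 2)

    x = jnp.ones((8, 3))
    y = jnp.zeros(8)
    w = jnp.array([1.0, 2.0, 3.0])

    grad_loss = jax.grad(loss)      # reverse-mode gradient with respect to w
    fast_grad = jax.jit(grad_loss)  # trace once, compile to XLA, reuse
    print(fast_grad(w, x, y))       # shape (3,)

    # vmap lifts a per-example function over the batch axis, no code changes.
    per_example = jax.vmap(lambda xi, yi: (xi @ w - yi) ** 2)
    print(per_example(x, y))        # shape (8,)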

The core of JAX relies on representing numerical programs via JAX's tracing machinery and applying transformations that lower computations into efficient XLA code. Key aspects like operator overloading, shape polymorphism, and representing programs as Python+NumPy allow JAX to accelerate pure numerical code while retaining Python's expressiveness. JAX also provides libraries with NumPy-like functionality and wrappers for other scientific Python packages to simplify porting code.

The main design choices are: leveraging tracing and XLA compilation to generate optimized device code behind a simple Python interface; implementing transformations such as autodiff, vectorization, and parallelization as manipulations of trace representations; and relying on extensive operator overloading so that idiomatic numerical Python code can be accelerated with minimal annotations from users.

Array Operations and Math Utilities

This section covers the core array manipulation, mathematical, and linear algebra functionality provided by JAX. It allows performing NumPy-like operations on JAX arrays while leveraging XLA for efficient computation.
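
For instance, a few NumPy-style operations on JAX arrays (an illustrative snippet, not taken from the wiki):

    import jax.numpy as jnp

    a = jnp.arange(6.0).reshape(2, 3)        # JAX arrays mirror the NumPy API
    row_norms = jnp.linalg.norm(a, axis=1)   # reductions and linalg dispatch to XLA
    print(jnp.sin(a) + row_norms[:, None])   # broadcasting works as in NumPy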

Core Array Operations

References: jax

The core array manipulation and shape utilities in JAX are implemented in the …/abstract_arrays.py file.
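
For a flavor of the abstraction involved, here is a sketch (assuming the ShapedArray abstract value, which JAX uses during tracing to describe an array's shape and dtype without any data):

    import jax.numpy as jnp
    from jax.core import ShapedArray

    # An abstract value: shape and dtype only. Tracing propagates these
    # through a program in place of concrete arrays.
    aval = ShapedArray((2, 3), jnp.float32)
    print(aval.shape, aval.dtype)  # (2, 3) float32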

Linear Algebra

References: jax

Matrix decompositions, solvers, eigenvalues, norms, and products are handled in JAX through the …/linalg.py file, which imports its linear algebra functions from …/linalg.py.
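
These functions are exposed under the familiar NumPy namespace; for example (illustrative):

    import jax.numpy as jnp

    m = jnp.array([[2.0, 1.0],
                   [1.0, 3.0]])
    w, v = jnp.linalg.eigh(m)             # eigendecomposition of a symmetric matrix
    x = jnp.linalg.solve(m, jnp.ones(2))  # linear solve
    print(jnp.linalg.norm(m))             # Frobenius norm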

Fourier Transforms

References: jax

Fourier transforms are handled in the …/ducc_fft.cc and …/ducc_fft_kernels.cc files. These files contain implementations of FFT kernels using the DUCC FFT library.
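
These kernels back JAX's NumPy-style FFT API; a small illustrative example:

    import jax.numpy as jnp

    signal = jnp.sin(jnp.linspace(0.0, 10.0, 64))
    spectrum = jnp.fft.fft(signal)      # forward FFT
    recovered = jnp.fft.ifft(spectrum)  # inverse FFT recovers the signal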

Parallel Operations

References: jax

Collectives, such as sums and gathers over mapped axes, are handled by the PmapExecutable class defined in …/pxla.py. PmapExecutable represents an XLA computation that has been parallelized for execution across multiple devices. It contains metadata such as the logical axes, input/output sharding, and the underlying XLA computation.
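
From user code, such collectives appear inside pmap-mapped functions; a minimal sketch (runs on however many local devices are available):

    import jax
    import jax.numpy as jnp

    n = jax.local_device_count()
    x = jnp.arange(float(n))  # one scalar per device

    # psum is a collective sum over the mapped axis named "i".
    total = jax.pmap(lambda v: jax.lax.psum(v, axis_name="i"), axis_name="i")(x)
    print(total)  # every device holds the same summed value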

Control Flow Primitives

JAX's control flow primitives allow expressing loops, conditionals, and custom solvers in a functional style. They leverage JAX's ability to stage functions out as jaxprs, translating Python control flow into efficient compiled computations.
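
For example, loops and conditionals written with the lax primitives (an illustrative sketch):

    import jax.numpy as jnp
    from jax import lax

    # lax.scan expresses a loop functionally: a carry plus per-step outputs.
    def step(carry, x):
        new_carry = carry + x
        return new_carry, new_carry

    total, running_sums = lax.scan(step, 0.0, jnp.arange(5.0))

    # lax.cond stages both branches and selects one at runtime.
    y = lax.cond(total > 5.0, lambda t: t * 2.0, lambda t: t, total)
    print(total, running_sums, y)  # 10.0 [0. 1. 3. 6. 10.] 20.0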

Neural Networks

References: jax/nn

Neural network functionality in JAX's jax.nn module is provided via reusable building blocks such as activation functions and weight initializers.
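
A small sketch of how these building blocks are typically used (the shapes are illustrative):

    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)
    init = jax.nn.initializers.glorot_normal()
    w = init(key, (4, 8))               # draw an initialized weight matrix

    x = jnp.ones((2, 4))
    h = jax.nn.relu(x @ w)              # activation functions live in jax.nn
    probs = jax.nn.softmax(h, axis=-1)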

Automatic Differentiation

References: jax

Automatic differentiation in JAX is implemented using the grad() and vjp() functions.
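
A minimal sketch of both entry points on a toy function:

    import jax
    import jax.numpy as jnp

    f = lambda x: jnp.sum(jnp.sin(x) ** 2)
    x = jnp.arange(3.0)

    g = jax.grad(f)(x)        # reverse-mode gradient

    y, f_vjp = jax.vjp(f, x)  # value plus a vector-Jacobian product closure
    (g2,) = f_vjp(1.0)        # pull back a scalar cotangent; g2 matches g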

Compilation and Parallelism

The PmapExecutable class is the primary way of representing parallel XLA computations that have been lowered from JAX. It contains the underlying XLA Computation, logical axis names, and metadata like operation shardings. The parallel_callable function compiles a callable for parallel execution by lowering it to a PmapExecutable. It handles sharding arguments and collecting results across devices.

Compilation

JIT compilation of functions in JAX is implemented using the apply_primitive function in …/xla.py. This function is responsible for lowering each JAX primitive into XLA during execution. It dispatches based on the primitive name and arguments, which drives how JAX computations are staged into XLA.
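
The staging this drives can be observed from user code with the public inspection APIs (a sketch; make_jaxpr shows the traced program, and the jit AOT interface shows the lowering):

    import jax
    import jax.numpy as jnp

    def f(x):
        return jnp.dot(x, x) + 1.0

    print(jax.make_jaxpr(f)(jnp.ones(3)))           # the staged-out jaxpr
    print(jax.jit(f).lower(jnp.ones(3)).as_text())  # the lowered HLO/StableHLO text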

Parallelism

The core functionality for parallelizing JAX computations across devices is contained in the …/pxla.py file. The PmapExecutable class represents a parallel mapping computation that has been lowered to XLA HLO. It contains the XLA Computation, local/global axis names, and metadata like op shardings. The parallel_callable function compiles a callable for parallel execution by lowering to a PmapExecutable. It handles sharding arguments and collecting results.

Vectorization

Vectorization in JAX is implemented using the vmap function defined in …/batching.py. vmap handles vectorizing functions over array axes by applying batching rules.
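
For example, in_axes controls which arguments are mapped and which are broadcast (illustrative):

    import jax
    import jax.numpy as jnp

    def affine(w, x):
        return w @ x              # (3, 3) @ (3,) -> (3,)

    w = jnp.eye(3)
    xs = jnp.ones((10, 3))

    # Map over the leading axis of xs only; w is shared across calls.
    ys = jax.vmap(affine, in_axes=(None, 0))(w, xs)
    print(ys.shape)  # (10, 3)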

Scientific Computing

References: jax/scipy

JAX's SciPy library provides optimized implementations of functionality from SciPy, scikit-learn, and other scientific Python packages. This allows leveraging JAX transformations like autodiff and JIT compilation within scientific computing libraries.
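
A small illustrative taste of the SciPy-style namespaces:

    import jax.numpy as jnp
    from jax.scipy import linalg, stats

    m = jnp.array([[4.0, 2.0],
                   [2.0, 3.0]])
    c = linalg.cholesky(m, lower=True)      # SciPy-style linalg under jax.scipy
    logp = stats.norm.logpdf(jnp.zeros(3))  # differentiable distribution functions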

SciPy Functionality

References: jax/scipy

Specialized SciPy functionality like linear algebra, FFTs, clustering, and other algorithms has been ported to work with JAX, enabling these techniques to take advantage of autodiff, GPUs, and TPUs.

Additional Scientific Libraries

References: jax

Beyond the core SciPy ports, other JAX-compatible scientific Python libraries provide specialized functionality for scientific computing tasks. They expose JAX-optimized versions of algorithms from packages like TensorFlow Probability and PyMC, offering a unified API on top of JAX while leveraging its lower-level primitives.

Debugging and Profiling

References: jax/tools

JAX provides a number of utilities for debugging and profiling code.
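
For example, jax.debug.print produces output even inside transformed code, where an ordinary print would only fire at trace time (a small sketch):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x):
        jax.debug.print("x inside jit = {}", x)  # printing that survives tracing
        return x * 2

    f(jnp.arange(3.0))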

Debugging Utilities

References: jax/tools

The utilities under jax/tools provide several helpers for debugging JAX programs and inspecting intermediate representations.

Profiling and Tracing

References: jax/tools

The profile function instruments a JAX function to collect profiling information when run. It returns a Profile object containing timing and memory usage stats. This allows developers to identify bottlenecks and optimize performance.
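
Separately from these tools, the public jax.profiler API offers a similar workflow; an illustrative sketch (the resulting trace can be viewed in TensorBoard or Perfetto):

    import jax
    import jax.numpy as jnp

    with jax.profiler.trace("/tmp/jax-trace"):
        x = jnp.ones((1024, 1024))
        (x @ x).block_until_ready()  # block so the async work lands in the trace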

Gradient Checking

References: tests

Gradient-checking utilities compute derivatives with forward- and reverse-mode automatic differentiation and compare the results against numerical finite differences, verifying that gradients are correct before they are used for optimization.
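
A minimal sketch using the check_grads helper from jax.test_util:

    import jax.numpy as jnp
    from jax.test_util import check_grads

    def f(x):
        return jnp.sin(x) * x

    # Checks forward- and reverse-mode derivatives against finite
    # differences, up to second order; raises on mismatch.
    check_grads(f, (jnp.array(0.7),), order=2, modes=("fwd", "rev"))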

Extensions and Customization

References: jax/extend

The …/mlir subdirectory allows lowering JAX computations to the MLIR intermediate representation and transforming them using MLIR passes. This enables optimizing JAX programs using the MLIR compiler infrastructure.

Extending JAX

References: jax/extend

JAX provides several mechanisms for extending its functionality in a controlled way. The …/extend directory contains utilities that expose lower-level JAX primitives and internals.

MLIR Integration

JAX integrates MLIR into its compiler pipeline to enable optimization of JAX programs via the MLIR intermediate representation. The key functionality is lowering JAX computations to MLIR IR and transforming the IR using MLIR passes.

Adding Custom Primitives

Adding custom primitives involves extending JAX's core library to expose new primitives for use in transformations. This allows users to define custom operators that JAX can differentiate, compile, and optimize.
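
A minimal sketch of the pattern (the mul2 primitive here is hypothetical; it uses jax.core.Primitive and the autodiff interpreter):

    import jax
    from jax import core
    from jax.interpreters import ad

    mul2_p = core.Primitive("mul2")  # a new primitive named "mul2"

    def mul2(x):
        return mul2_p.bind(x)

    mul2_p.def_impl(lambda x: x * 2)             # concrete evaluation rule
    mul2_p.def_abstract_eval(lambda aval: aval)  # output shape/dtype matches input
    ad.defjvp(mul2_p, lambda g, x: 2.0 * g)      # tangent rule for differentiation

    print(mul2(3.0))            # 6.0
    print(jax.grad(mul2)(3.0))  # 2.0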

Testing

References: tests

Comprehensive test suites for JAX functionality are implemented in several ways. The main approach is defining test case classes that inherit from jtu.JaxTestCase. These classes contain test methods that exercise different aspects of JAX functionality.
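
A minimal sketch of that pattern (the test class and assertion are illustrative; the test suites import the internal test_util module as jtu):

    from absl.testing import absltest
    import jax.numpy as jnp
    from jax._src import test_util as jtu

    class MyOpTest(jtu.JaxTestCase):
        def test_doubling(self):
            x = jnp.arange(4.0)
            self.assertAllClose(x * 2, x + x)  # tolerant array comparison helper

    if __name__ == "__main__":
        absltest.main(testLoader=jtu.JaxTestLoader())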

Test Suites

References: tests

The LaxAutodiffTest class contains many test methods that test the gradients of LAX primitives. These tests are parameterized over different LAX operations, argument shapes/types, and test configurations.

FileCheck Tests

References: tests/filecheck

Tests for lowering JAX computations into MLIR intermediate representation are implemented in several files in the …/filecheck directory. These tests use FileCheck to validate that the MLIR output from lowering JAX matches expectations.

Third Party Tests

References: tests/third_party

Compatibility tests against SciPy, NumPy, and other libraries thoroughly verify that JAX's implementations of third-party functionality produce the same results as the original implementations, ensuring JAX can act as a drop-in replacement.

Pallas Tests

References: tests/pallas

The …/pallas directory contains comprehensive tests for the Pallas compiler. The PallasTest base test case handles GPU/Triton configuration and caching of compiler outputs.

Building and Packaging

References: build

Scripts and utilities in the build directory handle building JAX from source and packaging it.

Building JAX

References: build/build.py, build

Building JAX involves building the jaxlib dependency from source using Bazel. The main script for this is …/build.py, which handles downloading Bazel, setting environment variables, writing the Bazel configuration file, and executing the Bazel build command.

Testing JAX

The build directory contains scripts for running JAX's comprehensive test suites and CI. These scripts coordinate parallel execution of tests across multiple accelerators like GPUs and TPUs.

Building Documentation

References: jax

Scripts for building JAX's documentation are handled by utilities in the build directory. The main one is …/build.py, which contains the logic for building JAX from source: downloading dependencies, configuring the build, and executing build commands.

ROCm Support

References: build/rocm

The …/rocm directory contains scripts and utilities for building and testing JAX with support for AMD ROCm GPUs. The main scripts are described in the subsections below.

Docker Container Build

The …/ci_build.sh script handles building a Docker container with the necessary ROCm dependencies for compiling and testing JAX packages. The script parses command line arguments to determine the Dockerfile path, Python version, and other build options. It then sets variables like WORKSPACE, BUILD_TAG, and DOCKER_IMG_NAME used throughout the build.

ROCm Build Script

The …/build_rocm.sh script is the main way to build JAX with support for AMD's ROCm GPUs. It handles configuring the environment and dependencies needed to build JAX with ROCm, and then runs the JAX build process with the proper ROCm-specific options.

Multi-GPU Testing

This script runs JAX tests that require multiple GPUs on AMD/ATI systems with varying GPU counts. It detects the number of AMD/ATI GPUs present using lspci and sets the HIP_VISIBLE_DEVICES environment variable accordingly: systems with 8 or more GPUs use 8 GPUs, systems with 4 to 7 use the first 4, and so on. It then runs the pmap_test.py and multi_device_test.py test modules on the selected devices.

Documentation

References: docs

The docs directory contains documentation files that provide tutorials, guides, design proposals, and reference material for JAX.

Tutorials

References: docs/tutorials

The …/tutorials directory contains in-depth tutorials demonstrating key capabilities of JAX.

Notebooks

References: docs/notebooks

Runnable Jupyter notebooks demonstrate key JAX techniques for machine learning research. The notebooks illustrate concepts through examples and exercises, allowing readers to experiment with JAX.

Design Proposals

References: docs/jep

JAX Enhancement Proposals (JEPs) propose major new features for JAX, along with improvements to performance, correctness, and usability. Many introduce new primitives, transformations, or abstractions; some rework existing functionality.

Documentation Utilities

The …/sphinxext directory contains custom Sphinx extensions that add roles for linking to JAX GitHub issues and DOI URLs from documentation. The jax_extensions.py file defines the jax_issue_role and doi_role functions, which generate hyperlinks to JAX issues and DOI URLs respectively. The setup function registers these roles with Sphinx, making the "jax-issue" and "doi" roles available for use in documentation markup.

JAX 101

References: docs/jax-101

JAX 101 introduces the basics of JAX: JAX NumPy, automatic differentiation with jax.grad(), just-in-time compilation with jax.jit(), vectorization with jax.vmap(), and more advanced topics such as random number generation, pytrees, parallelism, and state handling.
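
For example, the random-number material centers on JAX's explicit, splittable PRNG keys (a small sketch):

    import jax

    key = jax.random.PRNGKey(42)         # explicit, reproducible PRNG state
    key, subkey = jax.random.split(key)  # split rather than mutate
    x = jax.random.normal(subkey, (3,))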

Debugging Guide

References: docs/debugging

JAX provides several tools and techniques for debugging programs.

Pallas Documentation

References: docs/pallas

Pallas allows writing GPU/TPU kernels by extending JAX with Ref types, new primitives, and pallas_call. Kernels use Refs instead of JAX arrays, loading from and storing to memory via pallas.load and pallas.store. pallas_call maps the kernel over an iteration space specified by a grid and BlockSpecs.
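
A minimal sketch in the style of the Pallas quickstart (no grid, so the kernel body runs once over the whole array):

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    # Kernels receive Refs and read/write memory explicitly.
    def add_one_kernel(x_ref, o_ref):
        x = pl.load(x_ref, (slice(None),))        # load a block from memory
        pl.store(o_ref, (slice(None),), x + 1.0)  # store the result back

    x = jnp.arange(8.0)
    y = pl.pallas_call(
        add_one_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x)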
