Auto Wiki by Mutable.ai

jax

Auto-generated from google/jax by Mutable.ai Auto Wiki

jax
GitHub Repository
Developer: google
Written in: Python
Stars: 26k
Watchers: 322
Created: 2018-10-25
Last updated: 2024-01-05
License: Apache License 2.0
Homepage: jax.readthedocs.io
Repository: google/jax
Auto Wiki
Generated at: 2024-01-05
Generated from: Commit b8098b
Version: 0.0.4

JAX is a Python library that provides primitives for automatic differentiation, JIT compilation, vectorization, parallelization, and other transformations to accelerate and simplify machine learning research.

The key functionality of JAX centers around transformations that allow compiling and optimizing Python and NumPy code to run efficiently on accelerators like GPUs and TPUs. JAX transforms pure Python+NumPy functions into efficient low-level code by tracing computations and applying performance optimizations using just-in-time XLA compilation. Rather than explicitly coding algorithms for accelerators, JAX allows users to write numerical code in Python and apply transformations like jax.jit and jax.grad to get optimized device-specific implementations.

Some of the most important transformations and capabilities provided by JAX include (see the short sketch after this list):

  • Automatic differentiation with jax.grad() to easily take derivatives of functions. JAX uses operator overloading and traces code to build a representation of the computation graph which is used to compute gradients via reverse-mode automatic differentiation. This allows machine learning models to be trained by taking gradients of loss functions.

  • Just-in-time compilation with jax.jit() to compile and optimize Python functions into XLA (Accelerated Linear Algebra) computations that run efficiently on hardware like GPUs and TPUs. JAX traces and compiles pure Python+NumPy functions into optimized low-level code using XLA.

  • Vectorization with jax.vmap() to automatically vectorize functions over array axes for batching/parallelism. For example, a function mapping from vectors to scalars can be lifted to operate over matrices in a data-parallel fashion without any code changes.

  • Multi-device parallelization with jax.pmap() to automatically parallelize functions across multiple devices like GPUs and TPU cores. This implements data parallelism by sharding inputs across devices, running the function in parallel, and gathering outputs.
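
As a rough sketch of how these transformations compose (a minimal, illustrative example; the loss function and shapes are not taken from the repository):

    import jax
    import jax.numpy as jnp

    # Ordinary Python+NumPy-style numerical code.
    def loss(w, x, y):
        pred = jnp.dot(x, w)
        return jnp.mean((pred - y) ** 2)

    grad_loss = jax.grad(loss)       # reverse-mode gradient w.r.t. the first argument
    fast_grad = jax.jit(grad_loss)   # JIT-compile the gradient function via XLA

    # vmap lifts a per-example function to operate over a leading batch axis.
    per_example = jax.vmap(loss, in_axes=(None, 0, 0))

    w = jnp.ones(3)
    x = jnp.ones((8, 3))
    y = jnp.zeros(8)
    print(fast_grad(w, x, y))             # gradient of shape (3,)
    print(per_example(w, x, y).shape)     # per-example losses, shape (8,)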

The core of JAX relies on representing numerical programs via JAX's tracing machinery and applying transformations that lower computations into efficient XLA code. Key aspects like operator overloading, shape polymorphism, and representing programs as Python+NumPy allow JAX to accelerate pure numerical code while retaining Python's expressiveness. JAX also provides libraries with NumPy-like functionality and wrappers for other scientific Python packages to simplify porting code.

The main design choices are leveraging tracing and XLA compilation to generate optimized device code while keeping a simple Python interface, providing transformations like autodiff/vectorization/parallelization that work by manipulating trace representations, and extensive operator overloading to accelerate idiomatic numerical Python code with minimal annotations needed from users.

Array Operations and Math Utilities

References: jax/_src/numpy, jax/numpy, jax/_src/lax

This section covers the core array manipulation, mathematical, and linear algebra functionality provided by JAX. It allows performing NumPy-like operations on JAX arrays while leveraging XLA for efficient computation.

…/numpy contains the main implementation of the NumPy-compatible API, built on the ShapedArray abstraction for multi-dimensional arrays. Functions like abs(), add(), and arange() provide array creation and manipulation.

…/polynomial.py implements polynomial functions like roots(), polyfit(), polyval() using XLA. It handles special cases and provides a NumPy-compatible interface.

…/linalg.py contains functions like cholesky, svd, qr, eig that wrap lower-level primitives. Some have custom JVPs for differentiation.

…/reductions.py defines reductions like sum, prod, mean by calling _reduction(). It handles optional arguments, masking, and upcasting float16.

…/setops.py implements set operations like unique(), union1d(), intersect1d() for JIT compatibility using sorting and padding.

…/index_tricks.py provides classes for indexing tricks such as mgrid and ogrid, as well as helpers for building arrays by concatenation.

…/fft.py contains FFT functions like fft, ifft, fft2 implemented using XLA. It handles normalization and shifting spectra.

…/vectorize.py defines the vectorize() function for vectorizing functions over array axes. It uses vmap() internally.

…/util.py contains utilities for shape/dtype promotion, wrapping NumPy functions, and implementing operations like where().

…/array_methods.py dynamically adds NumPy-like methods to JAX arrays. It provides syntactic sugar for index updates using .at[].
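
A brief sketch of the user-facing API these modules implement (values and shapes are illustrative):

    import jax.numpy as jnp

    x = jnp.arange(6.0).reshape(2, 3)
    y = jnp.abs(-x) + 1.0          # NumPy-like elementwise operations
    s = x.sum(axis=0)              # reductions mirror their NumPy counterparts
    z = x.at[0, 1].set(10.0)       # functional index update via the .at[] property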

Core Array Operations

References: jax

The core array manipulation and shape utilities in JAX are implemented in the …/abstract_arrays.py file.

The UnshapedArray class represents an unshaped array value. It defines the core operations for manipulating arrays, including indexing, slicing, and broadcasting. The broadcast method handles broadcasting arrays to a common shape.

The ShapedArray class represents a shaped array value with a shape and dtype. It defines methods for accessing the shape and dtype of the array.

The eval_shape function computes the output shapes and dtypes of a computation without executing it on real data. This is useful for debugging shape errors.
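
A small example of using eval_shape with abstract inputs (a sketch; the function and shapes are illustrative):

    import jax
    import jax.numpy as jnp

    def f(x, w):
        return jnp.dot(x, w)

    # ShapeDtypeStruct carries only shape and dtype, so no FLOPs are executed.
    x = jax.ShapeDtypeStruct((128, 256), jnp.float32)
    w = jax.ShapeDtypeStruct((256, 512), jnp.float32)
    out = jax.eval_shape(f, x, w)
    print(out.shape, out.dtype)    # (128, 512) float32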

Linear Algebra

References: jax

Matrix decompositions, solvers, eigenvalue routines, norms, and products are handled in JAX through the …/linalg.py file, which re-exports linear algebra functions from …/linalg.py, including (a short usage sketch follows this list):

  • Decomposition methods like cholesky(), qr(), and svd() for Cholesky, QR, and singular value decompositions.

  • Eigenvalue solvers like eig() for finding eigenvalues and eigenvectors.

  • Norm functions like norm() for computing vector and matrix norms.

  • Solving linear systems with functions like solve() and triangular_solve().

  • Matrix products through functions like dot() for matrix-matrix and matrix-vector products.
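
A short usage sketch of the jnp.linalg interface described above (eigh is used here for the symmetric case; the matrices are illustrative):

    import jax.numpy as jnp

    a = jnp.array([[4.0, 1.0], [1.0, 3.0]])
    b = jnp.array([1.0, 2.0])

    l = jnp.linalg.cholesky(a)     # lower-triangular Cholesky factor
    q, r = jnp.linalg.qr(a)        # QR decomposition
    w, v = jnp.linalg.eigh(a)      # eigenvalues/eigenvectors of a symmetric matrix
    x = jnp.linalg.solve(a, b)     # solve the linear system a @ x = b
    n = jnp.linalg.norm(a)         # Frobenius norm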

The …/linalg_benchmark.py file contains benchmarks for many of these linear algebra functions, running them on random matrices of varying sizes to measure performance.

The …/solver.cc file contains GPU implementations of linear algebra solvers like LU, QR, and EVD. It defines kernels for:

  • Getrf() for LU decomposition
  • Orgqr() for forming the orthogonal factor in a QR decomposition
  • Syevd() for symmetric eigenvalue decomposition

These kernels wrap cuSOLVER calls and handle descriptors, streams, and errors.

Fourier Transforms

References: jax

Fourier transforms are handled in the …/ducc_fft.cc and …/ducc_fft_kernels.cc files. These files contain implementations of FFT kernels using the DuccFFT library.

The DuccFft and DynamicDuccFft functions define the main FFT functions, while BuildDynamicDuccFftDescriptor builds descriptors for dynamic FFTs.

The DuccFft function performs FFTs by calling the corresponding DuccFFT routine. It handles packing and unpacking data from JAX arrays into DuccFFT arrays, as well as setting up the DuccFFT plan.

The DynamicDuccFft function performs dynamic FFTs where the input size is only known at runtime. It builds a dynamic DuccFFT plan using BuildDynamicDuccFftDescriptor, which constructs the descriptor based on the input size and other arguments.

The key implementation details are:

  • DuccFFT is used to perform the low-level FFT computations. It provides high performance CPU FFT implementations.

  • Descriptors are used to represent the FFT problem configuration. They contain information like the input/output sizes and strides.

  • Descriptors for dynamic FFTs are built in BuildDynamicDuccFftDescriptor. They contain an additional "max_size" parameter for the maximum possible input size.

  • Plans are created from descriptors and cached for reuse. Dynamic plans are re-built when the input size changes.

  • Data is copied between JAX arrays and DuccFFT arrays during the FFT. This involves unpacking strides and padding for higher performance.

  • Errors are handled by checking the DuccFFT error code after each call and raising an error.

  • The BuildDynamicDuccFftDescriptor function plays a central role in enabling dynamic-sized FFTs, which are more flexible.

  • The DynamicDuccFft function implements the actual dynamic FFT by building a plan from the descriptor, performing the FFT, and cleaning up the plan.
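
At the Python level these kernels back the jnp.fft API; a minimal usage sketch (on CPU these calls dispatch to the DUCC-based kernels described above):

    import jax.numpy as jnp

    x = jnp.linspace(0.0, 1.0, 8)
    freq = jnp.fft.fft(x)                        # forward FFT
    back = jnp.fft.ifft(freq)                    # inverse transform recovers x (up to float error)
    centered = jnp.fft.fftshift(jnp.abs(freq))   # shift the zero-frequency bin to the center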

Parallel Operations

References: jax

Collectives such as sums and gathers over mapped axes are handled by the PmapExecutable class defined in …/pxla.py. PmapExecutable represents an XLA computation that has been parallelized for execution across multiple devices. It contains metadata like the logical axes, input/output sharding, and the underlying XLA computation.

The parallel_callable function in pxla.py compiles a callable for parallel execution by lowering to a PmapExecutable. It handles sharding arguments and outputs across devices, and inserting all-gather and all-reduce collectives as needed. These collectives are implemented using XLA HLO instructions.

During execution, PmapExecutable's execute method partitions inputs across devices, executes the underlying XLA computation on each device, and performs the necessary collectives to gather outputs and reduce values across devices.

The Sharding class represents a sharding specification, containing a mesh shape and a list of sharding dimensions. It is used to specify how to shard inputs and outputs across devices.

The create_mesh function creates a mesh of devices by specifying the mesh shape and device assignment function. This mesh is used during parallel execution to determine which devices correspond to which mesh coordinates.

In summary, parallel operations like sums and gathers over mapped axes are handled by the PmapExecutable class, which lowers parallel mappings to XLA computations containing the necessary collectives. The parallel_callable function compiles callables for parallel execution, while Sharding and create_mesh specify how data is partitioned across devices.
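
At the user level, these collectives appear via jax.pmap and jax.lax; a minimal sketch that assumes at least one local device (the function name is illustrative):

    from functools import partial
    import jax
    import jax.numpy as jnp

    n = jax.local_device_count()

    @partial(jax.pmap, axis_name="devices")
    def normalize(x):
        # psum is an all-reduce over the mapped axis; every device receives the total.
        return x / jax.lax.psum(x, axis_name="devices")

    # The leading axis of the input must equal the number of devices.
    out = normalize(2.0 * jnp.ones(n))   # each entry becomes 1 / n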

Control Flow Primitives

References: jax/_src/lax/control_flow

JAX's control flow primitives allow expressing loops, conditionals, and custom solvers in a functional style. They leverage JAX's ability to translate Python control flow into efficient linearized computations by staging functions as JAXPRs.

for_loop allows writing loops with read/write semantics using Refs. It traces the loop body to a Jaxpr, handles partial evaluation via a fixpoint, and supports batching/vectorization. Residual values are handled by converting outputs to writes into Refs.

cond conditionally applies true/false functions based on a predicate. It traces both branches and handles batching/vectorization, partial evaluation, and joining outputs across branches. Common subexpressions are extracted using _initial_style_jaxprs_with_common_consts.

switch applies an indexed branch function. It traces all branches for transformations like batching and vectorization.

custom_root and custom_linear_solve allow expressing iterative solvers and root finders by scanning user functions. They are implemented via primitives with custom JVP rules defined.

for_loop traces the loop body to a Jaxpr using _trace_to_jaxpr_with_refs. It hoists constants to Refs using _hoist_consts_to_refs to allow mutation. Partial evaluation uses _for_partial_eval which runs a fixpoint to determine unknown Refs. Residual values are handled by converting outputs to writes into Refs.

_cond traces both branches to Jaxprs using _initial_style_jaxprs_with_common_consts which unifies constants. This allows it to be converted to a select with a batched predicate. _join_cond_outputs joins outputs across branches.

custom_root's forward pass calls the solve Jaxpr, while its JVP computes the tangent of the solve function. custom_linear_solve's forward pass is a call to a primitive, while its JVP computes the RHS of the tangent equation.
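
A small sketch of the user-facing control flow API built on these primitives (lax.cond, lax.switch, and a bounded loop; the functions are illustrative):

    import jax
    import jax.numpy as jnp
    from jax import lax

    def step(x):
        # cond traces both branches and selects one based on the predicate.
        return lax.cond(x > 0, lambda v: v * 2.0, lambda v: -v, x)

    def pick(i, x):
        # switch indexes into a list of branch functions.
        return lax.switch(i, [jnp.sin, jnp.cos, jnp.tanh], x)

    # fori_loop expresses a bounded loop that stays inside the compiled computation.
    total = lax.fori_loop(0, 5, lambda i, acc: acc + i, 0)

    print(jax.jit(step)(3.0), pick(1, 0.0), total)   # 6.0 1.0 10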

Neural Networks

References: jax/nn

Neural network functionality in JAX is provided via reusable classes for layers, initializers, and activation functions.

The …/__init__.py file provides common activation functions like relu, sigmoid, and softmax via re-exporting from jax._src.nn.functions. It also re-exports the initializers submodule, which contains functions for initializing layer weights.

The …/initializers.py file defines initializers as callable classes that implement different initialization strategies, such as Xavier and He initialization, in a vectorized way. The Initializer base class standardizes the interface. Initializers like normal leverage JAX's RNG to sample from distributions, while orthogonal uses a QR decomposition.

In summary, this directory provides common activation functions and initializers as reusable components for neural network layers.
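
A brief sketch of using these components together (initializer and activation names are from jax.nn; the shapes are illustrative):

    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)
    init = jax.nn.initializers.he_normal()   # He/Kaiming initialization strategy
    w = init(key, (128, 64), jnp.float32)    # sample a weight matrix

    x = jnp.ones((8, 128))
    h = jax.nn.relu(x @ w)                   # activation functions from jax.nn
    p = jax.nn.softmax(h, axis=-1)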

Automatic Differentiation

References: jax

Automatic differentiation in JAX is implemented using the grad() and vjp() functions.

grad() computes the gradient of a function via reverse-mode automatic differentiation. It takes a function and returns a new function that evaluates the gradient with respect to the inputs.

vjp() computes the vector-Jacobian product, which can be used to efficiently compute gradients of functions with many inputs and few outputs. It returns both the function output and a "vjp function" that computes the vector-Jacobian product when applied to a cotangent (output) vector.

The core machinery lives in JAX's autodiff interpreter (ad.py), with helpers in the …/ad_util.py file:

  • The grad() function handles tracing the input function using JVPTracer from ad.py. This records primal and tangent values during the trace. It then replays the trace in "reverse" mode, accumulating the vector-Jacobian products to compute the gradient.

  • The vjp() function similarly traces the input function, but returns both the primal output and a "vjp function" that computes the vector-Jacobian product when applied to a cotangent vector.

  • The jvp() function computes the primal and tangent outputs of a function when given a tangent vector for its input. It replays the trace using JVPTracer.

  • The JVPTracer class records primal and tangent values during a trace. It handles propagating tangent values through primitives using their jvp rules.

In summary, automatic differentiation in JAX is implemented by tracing functions, recording primal and tangent values, then replaying the trace in forward or reverse mode to compute derivatives. Functions like grad() and vjp() provide a high-level interface, while JVPTracer handles the low-level tracing and tangent propagation.
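
A short sketch of the user-facing autodiff API (grad, vjp, and jvp on a scalar-valued function; the function is illustrative):

    import jax
    import jax.numpy as jnp

    def f(x):
        return jnp.sum(jnp.sin(x) ** 2)

    x = jnp.arange(3.0)

    g = jax.grad(f)(x)                        # reverse-mode gradient, same shape as x

    y, f_vjp = jax.vjp(f, x)                  # primal output plus a VJP closure
    (pullback,) = f_vjp(jnp.ones_like(y))     # apply a scalar cotangent; equals g

    y2, tangent_out = jax.jvp(f, (x,), (jnp.ones_like(x),))   # forward-mode JVP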

Compilation and Parallelism

References: jax/_src/lib, jax/interpreters

The PmapExecutable class is the primary way of representing parallel XLA computations that have been lowered from JAX. It contains the underlying XLA Computation, logical axis names, and metadata like operation shardings. The parallel_callable function compiles a callable for parallel execution by lowering it to a PmapExecutable. It handles sharding arguments and collecting results across devices.

The shard_arg function shards an argument array according to a ShardingSpec defined in jax._src.sharding_specs. Classes like Chunked, Replicated specify how the array should be partitioned across devices. The spec_to_indices function converts a sharding spec to device indices.

The xla_pmap_p primitive executes a PmapExecutable via XLA, mapping logical to physical devices. It uses thread_resources from jax._src.mesh to represent thread resources like devices.

The apply_primitive function in xla.py is responsible for lowering each JAX primitive to XLA during execution. It dispatches based on the primitive name and arguments, driving how JAX computations are staged into XLA.

The Backend class represents an XLA compiler backend like CPU/GPU that compiles and executes the emitted XLA computations. This is how the lowered XLA computations produced by JAX are ultimately run.

Compilation

References: jax/_src/lib, jax/interpreters

JIT compilation of functions in JAX is implemented using the apply_primitive function in …/xla.py. This function is responsible for lowering each JAX primitive into XLA during execution. It dispatches based on the primitive name and arguments, which drives how JAX computations are staged into XLA.

The Backend class represents an XLA compiler backend like CPU/GPU that compiles and executes the emitted XLA computations. When a primitive is lowered, the backend compiles it into an XLA computation, which is then cached. On subsequent calls, the cached XLA computation is reused, realizing just-in-time compilation.

The PmapExecutable class in …/pxla.py also plays an important role in JIT compilation. It represents a parallel mapping computation that has been lowered to an XLA HLO computation. The parallel_callable function compiles a callable for parallel execution by lowering it to a PmapExecutable. This compiles the function into an XLA computation that can be executed in parallel across devices.

In summary, JIT compilation in JAX is realized by lowering JAX primitives into XLA computations via apply_primitive, compiling those XLA computations into executable functions using the Backend, and caching the results. The PmapExecutable represents parallel mappings that have been JIT compiled for multi-device execution.
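
The same pipeline is exposed through jax.jit's ahead-of-time API, which can be used to inspect the lowered and compiled artifacts (a minimal sketch; the function is illustrative):

    import jax
    import jax.numpy as jnp

    def f(x):
        return jnp.sin(x) * 2.0

    lowered = jax.jit(f).lower(jnp.ones((4,)))   # trace and lower to compiler IR
    print(lowered.as_text()[:200])               # inspect the emitted module
    compiled = lowered.compile()                 # backend compiles and caches it
    print(compiled(jnp.ones((4,))))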

Parallelism

References: jax/_src/lib, jax/interpreters

The core functionality for parallelizing JAX computations across devices is contained in the …/pxla.py file. The PmapExecutable class is the primary way of representing and executing parallel mappings: it holds the XLA computation lowered to HLO, the local/global axis names, and metadata such as operation shardings. The parallel_callable function compiles a callable for parallel execution by lowering it to a PmapExecutable, sharding arguments with shard_arg and collecting results across devices.

The shard_arg function shards an argument array according to a ShardingSpec object, which defines how the array should be partitioned across devices. The xla_pmap_p primitive executes a PmapExecutable via XLA, mapping logical devices to physical devices.

Vectorization

References: jax/_src/lib, jax/interpreters

Vectorization in JAX is implemented by the batching machinery in …/batching.py, which underlies the vmap function. vmap vectorizes functions over array axes by applying batching rules.

The BatchTrace class represents a trace of a vectorized computation. It contains attributes like the input batch axes and nested subtraces for function calls.

The BatchTracer is used during tracing to insert vectorization rules at each primitive call. It applies the appropriate BatchingRule for the primitive, which specifies how the primitive should be vectorized.

Functions like vectorized_batcher implement default vectorization rules by returning a MapSpec object. This MapSpec describes how arguments should be mapped over batch axes for vectorization.

The register_vmappable function allows custom primitives to register their own BatchingRule implementation, specifying a custom vectorization rule.

The defvectorized function registers the default elementwise batching rule for a primitive, specifying how it behaves when vectorized under vmap.

In summary, the key components are the vmap function itself, the BatchTrace and BatchTracer classes used during tracing, and the BatchingRule interface which allows primitives and functions to specify their vectorization behavior.
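
A minimal sketch of vmap with explicit in_axes/out_axes (the function and shapes are illustrative):

    import jax
    import jax.numpy as jnp

    def predict(w, x):                 # maps a single example x to a scalar
        return jnp.tanh(jnp.dot(w, x))

    w = jnp.ones(4)
    xs = jnp.ones((10, 4))             # a batch of examples along axis 0

    batched = jax.vmap(predict, in_axes=(None, 0), out_axes=0)
    print(batched(w, xs).shape)        # (10,)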

Scientific Computing

References: jax/scipy

JAX's SciPy library provides optimized implementations of functionality from SciPy, scikit-learn, and other scientific Python packages. This allows leveraging JAX transformations like autodiff and JIT compilation within scientific computing libraries.

The core functionality is provided by submodules in …/scipy that implement functionality from their SciPy counterparts. This includes:

  • …/stats which provides probability distributions and statistical functions through an interface similar to SciPy stats. It contains distributions like norm, beta, binom, and functions like gaussian_kde.

These submodules import lower-level JAX implementations and re-export functionality through a SciPy-like API. They provide optimized versions of SciPy algorithms that leverage JAX.

The __init__.py file for each submodule acts as an import hook, re-exporting functionality to be accessed directly from jax.scipy. This provides a unified API.

Functions generally implement algorithms directly in a procedural style without classes. They take advantage of JAX to perform calculations efficiently on hardware accelerators.

gaussian_kde performs kernel density estimation with Gaussian kernels. It computes the kernel density estimate on an array of data.
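
A short sketch of the jax.scipy.stats interface, including gaussian_kde (the data and query points are illustrative):

    import jax.numpy as jnp
    from jax.scipy import stats

    x = jnp.linspace(-3.0, 3.0, 7)
    logp = stats.norm.logpdf(x, loc=0.0, scale=1.0)   # distribution functions
    cdf = stats.norm.cdf(x)

    data = jnp.array([0.1, 0.3, -0.2, 1.5, -1.1])
    kde = stats.gaussian_kde(data)                    # kernel density estimate
    density = kde.evaluate(jnp.array([0.0, 1.0]))     # evaluate at query points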

SciPy Functionality

References: jax/scipy

Specialized SciPy functionality like linear algebra, FFTs, clustering, and other algorithms have been ported to work with JAX, enabling these techniques to take advantage of autodiff, GPUs, and TPUs. This includes:

  • Linear algebra routines for matrix decompositions, solvers, and operations on large matrices. Functions like cholesky(), lu(), and qr() compute matrix decompositions while solve() solves systems of linear equations.

  • FFT implementations through functions like dct(), idct(), and fftconvolve() for performing fast Fourier transforms.

  • Clustering utilities such as vq() for vector quantization, which assigns input vectors to the nearest centroids in a codebook.

  • Sparse linear solvers like cg(), gmres(), and bicgstab() that use iterative methods to solve large sparse systems. They leverage sparse matrix-vector products and have different convergence properties.

  • Special functions through logpdf(), pdf(), cdf(), and others for probability distributions, Bessel functions, and more. Functions like gamma() and gammaln() compute the gamma function and its logarithm.

These SciPy algorithms have been ported to JAX while maintaining a similar interface, allowing them to take advantage of JAX's transformations for optimization, autodiff, and hardware acceleration.
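
A brief sketch combining a direct factorization from jax.scipy.linalg with an iterative solve (the matrix here is small and dense purely for illustration; a real sparse operator would supply the matvec):

    import jax.numpy as jnp
    from jax.scipy import linalg
    from jax.scipy.sparse.linalg import cg

    a = jnp.array([[4.0, 1.0], [1.0, 3.0]])
    b = jnp.array([1.0, 2.0])

    c, low = linalg.cho_factor(a)        # Cholesky factorization
    x = linalg.cho_solve((c, low), b)    # solve using the factorization

    # cg accepts a matrix-vector product closure instead of an explicit matrix.
    x_cg, info = cg(lambda v: a @ v, b)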

Additional Scientific Libraries

References: jax

Other scientific Python libraries now compatible with JAX provide specialized functionality for scientific computing tasks beyond the core SciPy implementations. These libraries expose JAX-optimized versions of algorithms from packages like TensorFlow Probability and PyMC. They provide a unified API on top of JAX while leveraging lower-level primitives.

The …/onnx2xla.py file shows compiling an ONNX model to XLA using JAX by tracing an ONNX interpreter. It implements common ONNX operators on NumPy and compiles the interpreter using XLA to accelerate execution. This allows using ONNX models with JAX.

The …/sparse_benchmark.py file benchmarks core sparse linear algebra operations from the JAX sparse library like conversion between dense/sparse formats and sparse matrix-vector multiplication. These benchmarks exercise JAX's compatibility with sparse scientific computing.

The …/datasets.py file contains functions for downloading and preprocessing common machine learning datasets. It provides a standardized API for loading data into JAX, demonstrating how JAX can be used for scientific tasks involving large datasets.

Debugging and Profiling

References: jax/tools

JAX provides a number of utilities for debugging and profiling code. At a high level:

  • The grad and jacfwd/jacrev functions in …/__init__.py allow taking gradients and Jacobians of functions for debugging.

  • The profile function instruments functions for performance profiling, showing timing and memory usage.

  • The check_grads function numerically estimates gradients to verify the correctness of gradient functions.

  • The jax_to_ir tools in …/jax_to_ir.py convert JAX functions to HLO or TensorFlow for debugging.

Together, these utilities provide essential tools for:

  • Checking gradients numerically
  • Profiling performance bottlenecks
  • Debugging JAX functions by converting them to other IRs

Debugging Utilities

References: jax/tools

This code provides several utilities for debugging JAX programs and inspecting intermediate representations.

The jax_to_ir() function in …/jax_to_ir.py is useful for debugging. It converts a JAX function into an intermediate representation (IR) in either the HLO or TensorFlow format. This allows inspecting the IR to debug issues in the generated computation.

The jax_to_hlo() and jax_to_tf() helpers provide a simple interface for converting to the HLO or TF formats. The conversion process involves:

  • Currying constants using functools.partial
  • Wrapping the function to control argument order
  • Running the function to produce an XlaComputation
  • Serializing the computation to the desired IR format

The main() function parses command line arguments to specify the input function, shapes, constants, and output location. It then calls jax_to_ir() to perform the actual conversion.

The parse_shape_str() function helps parse shape strings like "f32[2,3]" into ShapedArray objects needed for input shapes.

Overall, jax_to_ir() provides a useful debugging tool by exposing the intermediate representations generated by JAX, allowing issues to be identified and fixed.

Profiling and Tracing

References: jax/tools

The profile function instruments a JAX function to collect profiling information when run. It returns a Profile object containing timing and memory usage stats. This allows developers to identify bottlenecks and optimize performance.
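
One commonly used entry point is the jax.profiler module, which records traces viewable in TensorBoard's profiler plugin; a minimal sketch (the output path is an arbitrary example):

    import jax
    import jax.numpy as jnp

    # Record a profiler trace for a block of work.
    with jax.profiler.trace("/tmp/jax-trace"):
        x = jnp.ones((1000, 1000))
        # Block so the device work completes inside the traced region.
        y = jnp.dot(x, x).block_until_ready()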

Gradient Checking

References: tests

Functions for numerically checking gradients compare derivatives computed via automatic differentiation against numerical finite-difference estimates. This verifies gradients are computed accurately before they are used for optimization.

The check_grads function is central to gradient checking. It evaluates a function's derivatives in forward mode (via JVPs) and reverse mode (via VJPs), compares each against a numerical estimate, and checks that they agree within a provided tolerance.

The LaxAutodiffTest test class contains many test methods that use check_grads to verify the gradients of LAX operations. These tests are parameterized over different operations, argument shapes/types, and test configurations using @parameterized.parameters. This allows testing a wide range of cases in a maintainable way.

check_grads takes the function to differentiate, the inputs, the derivative order, the modes to check, and tolerances. It computes derivatives in the requested modes and compares the results against numerical estimates within the provided tolerance.

The tolerances are tailored to different dtypes to account for numerical precision differences. Tests are also skipped in some cases for certain devices like TPU where some operations are imprecise.
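
A minimal sketch of check_grads as exposed in jax.test_util (the function and tolerances are illustrative):

    import jax.numpy as jnp
    from jax.test_util import check_grads

    def f(x):
        return jnp.sum(jnp.tanh(x) ** 2)

    x = jnp.linspace(-1.0, 1.0, 5)

    # Check first- and second-order derivatives in both forward and reverse
    # mode against finite-difference estimates.
    check_grads(f, (x,), order=2, modes=("fwd", "rev"), atol=1e-2, rtol=1e-2)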

Extensions and Customization

References: jax/extend

The …/mlir subdirectory allows lowering JAX computations to the MLIR intermediate representation and transforming them using MLIR passes. This enables optimizing JAX programs using the MLIR compiler infrastructure.

The …/core.py file exposes types and constants from JAX's core for external use. The array_types object allows depending directly on JAX's abstract array types.

The …/linear_util.py file contains utilities for caching linear transformations, combining transformations, and wrapping functions to track applied linear transformations. The WrappedFun class is central to this functionality.

The …/passmanager.py file provides the PassManager class for constructing and executing sequences of MLIR optimization passes. This allows transforming MLIR modules representing JAX computations.

The …/ir.py file imports MLIR IR classes from the JAX library, allowing JAX to directly manipulate MLIR.

In summary, this directory provides controlled ways to extend JAX internals by exposing lower-level primitives. It also integrates MLIR to enable optimizing JAX programs. The WrappedFun class and MLIR integration are central to these capabilities.

Extending JAX

References: jax/extend

JAX provides several mechanisms for extending its functionality in a controlled way. The …/extend directory contains utilities that expose lower-level JAX primitives and internals.

The WrappedFun class in …/linear_util.py is central to extending JAX. It wraps a function and tracks any linear transformations applied to it. This allows caching transformed functions and combining transformations. The cache() function memoizes results, and merge_linear_aux() combines auxiliary data from multiple transformations.

The NameStack class in …/source_info_util.py uses a stack to efficiently track nested location contexts. As code moves through transformations, the NameStack is manipulated by functions like extend_name_stack() to keep track of where variables were originally defined.

The …/mlir directory lowers JAX computations to MLIR and exposes MLIR concepts to Python. The PassManager in …/passmanager.py constructs and executes sequences of MLIR optimization passes, transforming the MLIR.

The …/dialects files expose MLIR dialects used during lowering. For example, …/chlo.py provides utilities for the Chlo dialect which represents JAX computations as MLIR operations.

In summary, these utilities provide a controlled way to extend JAX internals by exposing lower-level primitives. Classes like WrappedFun and NameStack are central to implementing capabilities like tracking linear transformations and propagating source locations.

MLIR Integration

References: jax/extend/mlir, jax/extend/mlir/dialects

JAX integrates MLIR into its compiler pipeline to enable optimization of JAX programs via the MLIR intermediate representation. The key functionality is lowering JAX computations to MLIR IR and transforming the IR using MLIR passes.

Lowering is done by first converting JAX HLO to the MLIR Chlo dialect using utilities in …/__init__.py. The MLIR IR can then be manipulated using classes imported from …/ir.py.

Transformation and optimization of the MLIR occurs via sequences of passes managed by the PassManager class defined in …/passmanager.py. This provides a high-level interface for running individual optimization, analysis and conversion passes on the MLIR module.

Files like …/arith.py expose dialects like arithmetic to be used when lowering JAX to MLIR.

…/__init__.py registers dialects with MLIR so they can be used as IRs for JAX.

In summary, the MLIR integration code lowers JAX computations to MLIR IR and transforms the IR using passes managed by PassManager to optimize JAX programs. The dialects exposed allow representing JAX operations as MLIR.

Adding Custom Primitives

References: jax/extend/core.py, jax/extend/linear_util.py

Exposing additional primitives from JAX's internals involves extending JAX's core library to expose new primitives for use in transformations. This allows users to define custom operators that JAX can differentiate, compile, and optimize.

The WrappedFun class is central to exposing new primitives. It allows tracking how functions change when transformations are applied. By wrapping a custom primitive function in WrappedFun, it can be differentiated, compiled, and combined with other primitives.

The transformation and transformation_with_aux functions are used to define new transformation primitives. A user can define a custom transformation by writing a function that returns the result of transformation or transformation_with_aux.

The cache function is useful for caching the results of applying transformations to custom primitives. This avoids recomputing the same transformation multiple times.

The merge_linear_aux function combines the auxiliary data from applying multiple transformations to a custom primitive. Since applying several transformations sequentially will produce auxiliary data from each step, this must be merged.

The wrap_init function can be used to wrap custom initializers so they also track linear transformations, allowing them to be used as part of custom primitives.

In summary, by leveraging functions like WrappedFun, transformation, cache, and wrap_init, a user can define new primitives that JAX can differentiate, compile, and optimize just like its built-in operators.
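
For comparison, the conventional way to add a brand-new primitive (following the pattern in JAX's "How primitives work" documentation, not code from this directory) is to create a core.Primitive and register its rules; a hypothetical sketch with illustrative names:

    import numpy as np
    from jax import core

    mul_add_p = core.Primitive("mul_add")     # hypothetical primitive

    def mul_add(x, y, z):
        return mul_add_p.bind(x, y, z)        # user-facing wrapper

    # Concrete implementation used for eager evaluation.
    mul_add_p.def_impl(lambda x, y, z: np.add(np.multiply(x, y), z))

    # Abstract evaluation rule describing output shape/dtype for tracing.
    mul_add_p.def_abstract_eval(
        lambda x, y, z: core.ShapedArray(x.shape, x.dtype))

    print(mul_add(2.0, 3.0, 4.0))             # 10.0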

Testing

References: tests

Comprehensive test suites for JAX functionality are implemented in several ways. The main approach is defining test case classes that inherit from jtu.JaxTestCase. These classes contain test methods that exercise different aspects of JAX functionality.

The test methods are often decorated with @jtu.sample_product to parametrize the tests over different shapes, dtypes, arguments, and configurations. This allows testing a wide range of cases in a maintainable way.

Functions like _CheckAgainstNumpy() are used to compare JAX results to NumPy for correctness. _CompileAndCheck() runs functions both with and without jax.jit() and checks that the compiled results match.

Classes like LaxVmapOpTest contain logic to systematically test vectorizing operations over batch dimensions. Methods like _CheckBatching() call the operation on sliced batches to get the expected result, then compare to the vectorized result.

Test functions are often jitted using jax.jit() to ensure tests exercise the compiler backend. Functions like make_jaxpr() are used to inspect the intermediate representations.

Utilities like assertArraysEqual() and check_grads() are used to validate outputs and gradients respectively.

Tests are parameterized over different XLA versions and device types to cover a wide range of hardware. Special tolerances are used for certain devices.
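
A stripped-down sketch of this test pattern (plain absltest/parameterized here; the real suites use jtu.JaxTestCase and its helpers, and the test class below is illustrative):

    import numpy as np
    from absl.testing import absltest, parameterized
    import jax
    import jax.numpy as jnp

    class SumTest(parameterized.TestCase):

      @parameterized.parameters((2, 3), (4, 5), (1, 7))
      def test_sum_matches_numpy(self, rows, cols):
        x = np.arange(rows * cols, dtype=np.float32).reshape(rows, cols)
        expected = np.sum(x, axis=0)
        # Jit the function so the test also exercises the compiler backend.
        actual = jax.jit(lambda a: jnp.sum(a, axis=0))(x)
        np.testing.assert_allclose(actual, expected, rtol=1e-6)

    if __name__ == "__main__":
      absltest.main()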

Some important classes:

LaxAutodiffTest - Tests gradients of LAX primitives. Uses GradTestSpec to parameterize tests.

SparseSolverTest - Tests sparse linear solvers. Uses properties to test many cases.

CustomObjectTest - Tests lowering of custom objects.

MultiDeviceTest - Contains tests for multi-device functionality. Checks device placement.

Test Suites

References: tests

The LaxAutodiffTest class contains many test methods that test the gradients of LAX primitives. These tests are parameterized over different LAX operations, argument shapes/types, and test configurations.

The gradients are tested using the check_grads utility function, which compares derivatives computed via forward- and reverse-mode automatic differentiation against numerical estimates. This ensures the gradients computed in both modes are correct.

The GradTestSpec and GradSpecialValuesTestSpec named tuples are used to specify the gradient tests in a standardized way. They capture things like the operation to test, argument shapes/types, test configuration, and expected tolerances.

The main logic lives in the check_grads function. It evaluates derivatives in forward mode (JVPs) and reverse mode (VJPs), then checks that they agree with numerical finite-difference estimates within the provided tolerance.

The tests are parametrized over many argument shapes and dtypes using @parameterized.parameters decorators. This covers a wide range of cases and helps ensure correctness.

FileCheck Tests

References: tests/filecheck

Tests for lowering JAX computations into MLIR intermediate representation are implemented in several files in the …/filecheck directory. These tests use FileCheck to validate that the MLIR output from lowering JAX matches expectations.

The print_ir() function from …/jax_filecheck_helpers.py is used to print the MLIR for JAX computations. It compiles functions with jax.jit() and calls the lower() function to get the MLIR, then prints it.

Several files contain tests that call print_ir() on JAX functions to print the MLIR, then use FileCheck to validate the IR.

…/array.filecheck.py contains tests for lowering various JAX array operations to MLIR. It calls functions like lax.concatenate and passes them to print_ir() to print the resulting MLIR, then checks that MLIR matches patterns.

…/shapes.filecheck.py and …/math.filecheck.py contain similar tests for lowering JAX operations on different shape/type combinations and for elementwise math operations respectively.

…/names.filecheck.py checks that the MLIR modules are named correctly based on the JAX computation.

…/subcomputations.filecheck.py contains tests for functions like mlir.merge_mlir_modules that manipulate the MLIR.

…/custom_call.filecheck.py tests the mlir.custom_call() function by printing MLIR from calls with different configurations and checking the IR.

These tests provide comprehensive coverage of lowering JAX computations to MLIR, validating the compiler backend.

Third Party Tests

References: tests/third_party

Compatibility tests against SciPy, NumPy, and other libraries focus on verifying that JAX implementations of functionality from third-party packages produce the same results as the original implementations. This ensures JAX can serve as a drop-in replacement.

The main focus is in …/scipy, which contains tests for JAX's SciPy compatibility. The key test is in …/line_search_test.py which validates the line_search function that performs line searches for optimization algorithms.

The TestLineSearch class contains many tests for line_search. It defines scalar and multi-dimensional test functions, and tests that line_search finds steps that satisfy the Wolfe conditions on these functions. There are generic tests that try different scalar and line search problems. It also tests bounds handling and compares results to SciPy's line search.

The tests cover a wide range of scenarios, inputs, parameters and edge cases to thoroughly validate that line_search produces the same results as SciPy's implementation. This includes:

  • Generic tests that try different scalar and line search problems
  • Tests that ensure line_search finds steps satisfying the Wolfe conditions
  • Tests of bounds handling via the maxiter argument
  • Comparisons of results to SciPy's line search
  • Tests using different starting points

By thoroughly testing line_search across many scenarios and comparing to SciPy, the tests ensure JAX's SciPy compatibility for this key optimization functionality. This rigorous testing strategy is used throughout …/third_party to validate JAX implementations of functionality from NumPy, SciPy and other libraries.

Pallas Tests

References: tests/pallas

The …/pallas directory contains comprehensive tests for the Pallas compiler. The PallasTest base test case handles GPU/Triton configuration and caching of compiler outputs.

The test classes validate correctness, gradients, and vectorization of Pallas compiled functions. Utilities like pl.load, pl.store, and slicing are used extensively.

…/indexing_test.py contains property-based tests using Hypothesis to validate the NDIndexer class. Strategies generate a wide range of shapes and indexers. The tests ensure the integer indexer shape matches expectations and that the NDIndexer implementation behaves as specified.

…/all_gather_test.py tests the all_gather collective by sharding arrays across devices and meshes. It randomly generates array shapes and dtypes, shards arrays across 1D and 2D meshes, runs the all_gather, and checks equality with the unsharded array.

Building and Packaging

References: build

Scripts and utilities for building JAX from source and packaging it handle the following tasks:

  • Building the JAX library from source code using the …/build.py script. This script handles downloading dependencies like Bazel, checking environment variables, writing Bazel configuration files, and executing the Bazel build commands to compile JAX.

  • Running JAX's test suites using the …/parallel_accelerator_execute.sh script, which coordinates parallel test execution across multiple accelerators.

  • Building documentation using scripts in the jax directory.

  • Building JAX with support for AMD ROCm GPUs using scripts in …/rocm. This includes building a Docker container with ROCm dependencies, the main ROCm build script, and scripts for multi-GPU testing.

The …/build.py file contains the main logic for building JAX. Its main() function parses arguments, writes the Bazel config, then executes the Bazel build command to compile JAX. It can optionally build GPU plugins by passing different Bazel targets.

The …/parallel_accelerator_execute.sh script coordinates parallel test execution. It:

  • Gets environment variables for the number of accelerators and tests per accelerator
  • Finds the test binary using rlocation()
  • Acquires a lock file for an accelerator/test slot combination using flock
  • Sets environment variables like CUDA_VISIBLE_DEVICES within a subshell
  • Runs the test binary

Scripts in …/rocm build and test JAX with ROCm GPU support. The …/ci_build.sh script builds a Docker container containing ROCm dependencies. It:

  • Parses command line arguments
  • Builds the Docker image using the docker build command, calling either the "rt_build" or "ci_build" target
  • Runs commands inside the container
  • Optionally commits the container as a new image

Building JAX

References: build/build.py, build

Building JAX involves building the libjax dependency from source using Bazel. The main script for this is …/build.py, which handles downloading Bazel, setting environment variables, writing the Bazel configuration file, and executing the Bazel build command.

The get_bazel_path() function will find the Bazel binary either from the provided path, the PATH, or by downloading and verifying a Bazel binary of the correct version.

The write_bazelrc() function generates the Bazel configuration file by taking all the build options as arguments and writing them as flags. This sets paths for things like CUDA, ROCM, and build configurations.

Command line arguments are parsed to control build options like enabling CUDA/ROCM, CPU flags, and paths to libraries. The add_boolean_argument() function handles boolean flags.

The main logic in main() validates options, prints environment, writes the Bazel config file, then executes the Bazel build command via shell() to build libjax. It can optionally build GPU plugins by passing different Bazel targets.

The …/build_rocm.sh script handles building JAX with ROCm support. It clones the XLA repository if needed, sets environment variables, and builds and installs the JAX wheel with ROCm enabled.

The …/parallel_accelerator_execute.sh script coordinates parallel execution of Bazel tests across multiple GPUs/TPUs by assigning each test to an accelerator.

Testing JAX

References: build/parallel_accelerator_execute.sh, build

The build directory contains scripts for running JAX's comprehensive test suites and CI. These scripts coordinate parallel execution of tests across multiple accelerators like GPUs and TPUs.

The main script is …/parallel_accelerator_execute.sh, which coordinates access to accelerators for concurrent Bazel tests. It assigns each test to an accelerator to ensure even distribution across devices. When run with Bazel using the --run_under flag, it will execute tests in parallel on the available accelerators.

The script acquires lock files using flock to reserve accelerators, then sets environment variables like CUDA_VISIBLE_DEVICES and TPU_VISIBLE_CHIPS within a subshell to control the target device. It implements a basic resource reservation scheme to coordinate parallel test execution across multiple accelerators.

The …/rocm directory contains scripts for building and testing JAX with support for AMD ROCm GPUs. The run_multi_gpu.sh script detects the number of AMD GPUs present and runs JAX tests that require multiple GPUs, configuring the tests to use the available GPUs. The ci_build.sh script builds a Docker image containing the necessary dependencies to build and test JAX for ROCm.

The build.py script handles executing Bazel test commands to run JAX's test suites. It parses command line arguments to control options like enabling GPUs. The main() function executes the Bazel test command via shell() to run the tests.

These scripts provide an automated and parallelized mechanism for running JAX's comprehensive test suites and CI across multiple accelerators. They coordinate access to devices, build necessary environments, and execute the test commands with the proper options.

Building Documentation

References: jax

Scripts for building JAX's documentation are handled by utilities in the build directory. The main script is …/build.py, which contains logic for building the JAX library from source code. It handles downloading dependencies, configuring builds, and executing test commands.

The build.py script executes Bazel build commands to build JAX's dependencies and documentation. Bazel targets like //docs:docs build the Sphinx documentation, while //docs:html generates the HTML output.

The …/parallel_accelerator_execute.sh script coordinates parallel execution of Bazel tests across multiple GPUs/TPUs. It uses lock files to reserve accelerators for each test slot, then runs the test within a subshell after setting environment variables to control the target device.

Functions in the script include:

  • rlocation() - Locates the test binary within the Bazel runfiles
  • run_test() - Runs a test within a subshell after acquiring a lock

The …/rocm subdirectory contains scripts specific to building JAX's documentation for AMD ROCm GPUs. These scripts configure the build and test environment for ROCm.

In summary, the build scripts in build handle building JAX's documentation by executing the Bazel build command after configuring dependencies and the build environment. They coordinate parallel testing and provide functionality specific to different hardware.

ROCm Support

References: build/rocm

The …/rocm directory contains scripts and utilities for building and testing JAX with support for AMD ROCm GPUs. The main scripts are:

  • …/build_rocm.sh: This script handles building JAX with ROCm support by setting environment variables, building the XLA dependency, and running the JAX build with the proper ROCm configuration. It first checks if the XLA_CLONE_DIR environment variable is set, and if so uses that path for the XLA repository instead of cloning. If XLA_CLONE_DIR is not set, it will clone the XLA repo from GitHub with default values. It then exports the ROCM version and runs the JAX build script with the ROCm and XLA options, installing the built wheel.

  • …/ci_build.sh: This script builds a Docker image containing the necessary dependencies to build and test JAX for ROCm. It parses command line arguments, sets variables, and builds the Docker image calling either the "rt_build" or "ci_build" target. It then runs commands inside the built Docker container, optionally committing the container as a new image.

These scripts reuse common functions defined in …/build_common.sh like die for error handling and calc_elapsed_time for timing. The die function in particular standardizes error handling across scripts.

The …/run_multi_gpu.sh script detects the number of AMD GPUs present and configures JAX tests requiring multiple GPUs to run on the available devices by setting the HIP_VISIBLE_DEVICES environment variable.

Docker Container Build

References: build/rocm/ci_build.sh

The …/ci_build.sh script handles building a Docker container with the necessary ROCm dependencies for compiling and testing JAX packages. The script parses command line arguments to determine the Dockerfile path, Python version, and other build options. It then sets variables like WORKSPACE, BUILD_TAG, and DOCKER_IMG_NAME used throughout the build.

The main functionality is building the Docker image using either the "rt_build" or "ci_build" target, running commands inside a container launched from that image, and optionally committing the container as a new image. This allows testing JAX in the ROCm environment.

The docker build command builds the Docker image, calling the appropriate target based on RUNTIME_FLAG. It passes arguments for the Python version and ROCm paths. The docker run command launches a container from the built image, mounting the workspace, setting environment variables, and running the provided command. It uses ROCM_EXTRA_PARAMS for device mounting.

If KEEP_IMAGE is set, the script will commit the container as a new image after the command finishes successfully. The upsearch function finds the WORKSPACE, and command line arguments are parsed with a case statement on $1.

ROCm Build Script

References: build/rocm/build_rocm.sh

The …/build_rocm.sh script is the main way to build JAX with support for AMD's ROCm GPUs. It handles configuring the environment and dependencies needed to build JAX with ROCm, and then runs the JAX build process with the proper ROCm-specific options.

The script first checks if the XLA repository has already been cloned by checking the XLA_CLONE_DIR environment variable. If set, it will use that directory, otherwise it will clone the XLA repository from GitHub based on the XLA_REPO and XLA_BRANCH variables.

It then extracts the ROCm version from /opt/rocm/.info/version and exports it as JAX_ROCM_VERSION.

The main part of the build is running python3 ./build/build.py, passing the ROCm path. This builds a JAX wheel with ROCm support.

The script then installs this built wheel locally, as well as installing JAX itself.

For CI, it outputs the installed JAX version and ROCm version to files so they can be verified.

Multi-GPU Testing

References: build/rocm/run_multi_gpu.sh

This script allows JAX tests that require multiple GPUs to run on AMD/ATI GPU systems with different numbers of GPUs. It detects the number of AMD/ATI GPUs present using lspci and then sets the HIP_VISIBLE_DEVICES environment variable accordingly to specify which GPU devices should be used for the tests. For systems with 8 or more GPUs, all 8 GPUs will be used. For 4 to 7 GPUs, the first 4 GPUs will be used, and so on. It then runs the pmap_test.py and multi_device_test.py test modules to execute the relevant tests using the specified GPUs.

The lspci command is run and the number of AMD/ATI GPU controllers is saved in the $cmd variable. An if/elif/else block is used to set the HIP_VISIBLE_DEVICES environment variable based on the number of GPUs detected. For 8 or more GPUs, HIP_VISIBLE_DEVICES is set to 0,1,2,3,4,5,6,7. For 4 to 7 GPUs, it is set to 0,1,2,3, and so on.

The python3 -m pytest command is run with the --reruns 3 -x tests/pmap_test.py arguments to execute the pmap_test.py module tests using the GPUs specified. This command is run again to also test the multi_device_test.py module.

Documentation

References: docs

The docs directory contains documentation files that provide tutorials, guides, design proposals, and reference material for JAX.

…/tutorials contains in-depth tutorials demonstrating key JAX capabilities.

…/notebooks contains Jupyter notebooks that illustrate JAX techniques. Many important notebooks demonstrate techniques like:

  • Custom differentiation rules
  • Distributed arrays and automatic parallelization
  • Neural network training

…/jep contains JAX Enhancement Proposals for major features. They discuss designs for:

  • New primitives and transformations
  • Type safety and type promotion semantics
  • Random number generation
  • Checkpointing and rematerialization

…/jax_extensions.py defines Sphinx extensions that add roles for linking to JAX GitHub issues and DOI URLs from documentation.

The roles are registered in the setup function to be used in documentation.

Tutorials

References: docs/tutorials

The …/tutorials directory contains in-depth tutorials demonstrating key capabilities of JAX. The tutorials cover:

  • Automatic differentiation using jax.grad and higher derivatives. Gradients can be computed with respect to nested data structures.
  • Vectorization over batches using jax.vmap to automatically parallelize functions.
  • Debugging tools like jax.debug.print, jax.debug.breakpoint, and jax.debug.callback.
  • JIT compilation with jax.jit to speed up functions. Tracing is used to generate optimized XLA code.
  • Working with pytrees to represent model parameters and apply updates efficiently.
  • Random number generation using explicit keys that allow reproducibility and parallelism.

jax.jit compiles functions by first tracing them to produce a jaxpr representation. Tracers capture JAX operations without running them, producing a side-effect free jaxpr. Control flow is handled by only including the traced path.

jax.vmap automatically generates a vectorized version of a function by adding batch axes to inputs. It can vectorize only some arguments using in_axes. Functions can be composed, so a vmapped function can be jitted and vice versa.

jax.debug.print prints traced array values. jax.debug.breakpoint pauses execution to inspect values. jax.debug.callback registers a function to call during evaluation, making it compatible with transformations.
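
A compact sketch tying these pieces together (tracing with make_jaxpr and a debug print that survives jit; the function is illustrative):

    import jax
    import jax.numpy as jnp

    def f(x):
        jax.debug.print("traced value: {x}", x=x)   # prints even under jit
        return jnp.where(x > 0, x, 0.0)

    print(jax.make_jaxpr(f)(1.0))    # inspect the jaxpr produced by tracing
    print(jax.jit(f)(-2.0))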

Notebooks

References: docs/notebooks

Runnable Jupyter notebooks demonstrate key JAX techniques for machine learning research. The notebooks illustrate concepts through examples and exercises, allowing readers to experiment with JAX.

Many of the notebooks train neural networks on benchmark datasets to show end-to-end examples. They leverage libraries like TensorFlow Datasets for input pipelines and data loading, while using JAX transformations for performance.

The notebooks aim to build intuition for key JAX concepts like autodiff, JIT compilation, vectorization, and parallelization, providing a practical complement to the reference documentation.

Design Proposals

References: docs/jep

JEPs propose major new features for JAX, improvements to performance, correctness, usability, and more. Many introduce new primitives, transformations, or abstractions. Some rework existing functionality.

…/14273-shard-map.md introduces shard_map for multi-device programming. It controls device placement using a Mesh object and PartitionSpecs specify how inputs/outputs are split.

…/2026-custom-derivatives.md proposes custom_jvp_call and custom_vjp_call to customize differentiation. Users provide JVP/VJP rules. A custom_lin primitive is staged out to linear jaxprs, allowing processing VJP rules in two steps.

…/9263-typed-keys.md represents typed PRNG keys as extended dtypes using ExtendedDtype. The PRNGImpl class carries PRNG implementation metadata in each key's dtype.

…/11830-new-remat-checkpoint.md updates jax.checkpoint to allow custom rematerialization policies. It enables rematerializing constants, has lower Python overhead, and provides more features.

…/9407-type-promotion.md designs JAX's type promotion semantics. It analyzes options and chooses to constrain promotions to satisfy the lattice property for predictability while avoiding wider promotions.
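
The user-facing API that grew out of the custom-derivatives proposal is jax.custom_jvp/jax.custom_vjp; a minimal sketch of the former (the function is the standard log1pexp example):

    import jax
    import jax.numpy as jnp

    @jax.custom_jvp
    def log1pexp(x):
        return jnp.log(1.0 + jnp.exp(x))

    @log1pexp.defjvp
    def log1pexp_jvp(primals, tangents):
        (x,), (t,) = primals, tangents
        # Numerically stable derivative: sigmoid(x).
        return log1pexp(x), t * jax.nn.sigmoid(x)

    print(jax.grad(log1pexp)(10.0))   # uses the custom JVP rule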

Documentation Utilities

References: docs/sphinxext, docs/_templates, docs/_static

The …/sphinxext directory contains custom Sphinx extensions that add roles for linking to JAX GitHub issues and DOI URLs from documentation. The jax_extensions.py file defines the jax_issue_role and doi_role functions, which generate hyperlinks to JAX issues and DOI URLs respectively. The setup function registers these roles with Sphinx, making the "jax-issue" and "doi" roles available for use in documentation markup.

The jax_issue_role function implements the "jax-issue" role by:

  • Validating the issue number input
  • Constructing the GitHub issue URL from the validated input
  • Returning a nodes.reference node containing a hyperlink to the issue URL

The doi_role function implements the "doi" role by:

  • Unescaping the DOI identifier input
  • Splitting the input on explicit title parts
  • Constructing the DOI URL from the identifier
  • Setting the hyperlink title from the given or default title
  • Returning a nodes.reference node containing the hyperlinked DOI URL

The layout.html template in …/_templates extends the "!layout.html" template to inherit the base HTML layout. It adds the …/style.css file to the list of CSS files included in the document head.

The …/style.css file:

  • Sets background colors for different documentation sections using selectors like .getting-started
  • Sets background colors for code blocks within certain divs using selectors like div.red-background pre
  • Defines a --block-bg-opacity CSS custom property used to set the opacity of semi-transparent background colors.

JAX 101

References: docs/jax-101

JAX 101 introduces the basics of JAX, including JAX NumPy, automatic differentiation with jax.grad(), just-in-time compilation with jax.jit, vectorization with jax.vmap, and more advanced topics like random number generation, pytrees, parallelism, and state handling.

The file …/01-jax-basics.md shows how to get started with JAX NumPy, which behaves very similarly to NumPy but returns Jax array types. It demonstrates computing gradients of functions using jax.grad() and building simple machine learning training loops. An example linear regression model defines a model(), loss_fn(), and update() function used in gradient descent training.

jax.vmap() is used for vectorization across batch dimensions, as shown in …/03-vectorization.md. It demonstrates manually vectorizing a 1D convolution versus using jax.vmap() to automatically vectorize. jax.vmap() is applied with in_axes and out_axes specifications.

jax.jit performs just-in-time compilation as discussed in …/02-jitting.md. jax.make_jaxpr() views the jaxpr representation. jax.jit caches compiled functions but recompilation can occur.

Random number generation uses PRNG keys that must be split instead of global state as covered in …/05-random-numbers.md.

Pytrees are used to represent nested data structures, using functions like jax.tree_map() from …/05.1-pytrees.md.

jax.pmap() performs data parallelism across devices as shown in …/06-parallelism.md. It splits data using in_axes and communicates across devices.

Stateful computations are handled by passing state in/out of functions as arguments. This is demonstrated on a counter and linear regression example in …/07-state.md.
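
A minimal sketch combining pytree parameters, grad, and tree_map for an update step (the model and shapes are illustrative):

    import jax
    import jax.numpy as jnp

    params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}   # a pytree of parameters

    def loss(p, x, y):
        return jnp.mean((x @ p["w"] + p["b"] - y) ** 2)

    x, y = jnp.ones((8, 3)), jnp.zeros((8, 2))
    grads = jax.grad(loss)(params, x, y)                  # gradients mirror the pytree

    # Gradient-descent update applied leaf-by-leaf with tree_map.
    params = jax.tree_map(lambda p, g: p - 0.1 * g, params, grads)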

Debugging Guide

References: docs/debugging

JAX provides several tools and techniques for debugging programs.

jax.debug.print allows printing values inside JIT-compiled functions for inspection. It works similarly to regular print but can handle transformations. It reveals how computations are evaluated under different transformations.

jax.debug.breakpoint pauses execution and allows inspecting values in the debugger when hit. It utilizes jax.debug.callback underneath. It can detect nan/inf values using jax.lax.cond.

jax.experimental.checkify allows adding runtime error checks to JAX code. checkify.checkify transforms functions to return an error value instead of raising exceptions, making them compatible with transformations. checkify.check is used to add checks.

Setting jax_debug_nans detects NaNs in JIT-compiled code by raising exceptions. It precisely detects the primitive that produced the NaN.

Setting jax_disable_jit disables JIT compilation, enabling standard Python debugging tools. However, it slows things down and is incompatible with jax.pmap and jax.pjit.

checkify.checkify "functionalizes" the effects of checks so they become boolean operations merged into the returned error value. This "discharges" effects so checked functions behave like pure functions. It threads the error through the function and lowers checks to XLA primitives. Checks are added using checkify.check, which takes an error message format string. Automatic checks can also be selected. Mapped checked functions return a mapped error containing per-element errors.
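
A minimal checkify sketch (the error message and inputs are illustrative):

    import jax
    import jax.numpy as jnp
    from jax.experimental import checkify

    def safe_log(x):
        checkify.check(jnp.all(x > 0), "input must be positive")
        return jnp.log(x)

    checked = checkify.checkify(safe_log)     # returns (error, output) instead of raising

    err, out = jax.jit(checked)(jnp.array([1.0, 2.0]))
    err.throw()                               # no-op: no check failed

    err, out = jax.jit(checked)(jnp.array([-1.0, 2.0]))
    print(err.get())                          # reports the failed check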

Pallas Documentation

References: docs/pallas

Pallas allows writing GPU/TPU kernels by extending JAX with Ref types, new primitives, and pallas_call. Kernels use Refs instead of JAX arrays, loading from and storing to memory via pallas.load and pallas.store. pallas_call maps the kernel over an iteration space specified by grid and Specs.

Ref types represent mutable memory kernels can read from and write to. pallas.load/pallas.store simplify loading from and storing to Refs. pallas_call executes the kernel by mapping it over the iteration space defined by grid and Specs. Specs like BlockSpec allow transforming inputs/outputs per grid index, allowing "carving up" large arrays into blocks that fit in fast on-chip memory.

The core workflow revolves around pallas_call and Refs. Kernels operate on Refs instead of returning values, writing outputs to the provided Refs. pallas_call lifts the kernel into a JAX computation by mapping it over an iteration space defined by out_shape, grid, and Specs.

BlockSpecs play an important role, specifying block shapes and mapping of programs to blocks for inputs/outputs. This allows partitioning work across blocks that fit in fast memory. pallas_call passes the appropriate block Refs to each kernel invocation.

In matmul, BlockSpecs express inputs/outputs as block matrices. Each program computes one output block, implementing a recursive strategy. program_id allows the same kernel to be executed on different data blocks.

Kernels can support fused operations by passing functions as arguments, and pallas_call composes with jax.vmap thanks to its functional nature.
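
A minimal Pallas kernel sketch in the style of the documentation's introductory example (interpret=True lets it run without a GPU/TPU; the kernel and shapes are illustrative):

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def add_kernel(x_ref, y_ref, o_ref):
        # Refs are read and written explicitly; [...] touches the whole block.
        o_ref[...] = x_ref[...] + y_ref[...]

    @jax.jit
    def add(x, y):
        return pl.pallas_call(
            add_kernel,
            out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
            interpret=True,    # run in interpreter mode when no accelerator is available
        )(x, y)

    print(add(jnp.arange(8.0), jnp.arange(8.0)))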