Auto Wiki by Mutable.ai

flash-attention

Auto-generated from Dao-AILab/flash-attention by Mutable.ai Auto Wiki

flash-attention
GitHub Repository
Developer: Dao-AILab
Written in: Python
Stars: 8.9k
Watchers: 96
Created: 2022-05-19
Last updated: 2024-01-08
License: BSD 3-Clause "New" or "Revised"
Repository: Dao-AILab/flash-attention

Auto Wiki
Generated at: 2024-01-08
Generated from: Commit abbc13
Version: 0.0.4

The Flash Attention repository provides an optimized framework and implementations for efficient attention mechanisms used in deep learning models like transformers. It focuses on optimizing memory usage, throughput, and latency of attention relative to standard PyTorch implementations.

At the core, functions like flash_attn_func in …/__init__.py implement the basic multi-head attention computation. Variants handle different input formats like packed queries/keys/values for efficiency. flash_attn_with_kvcache supports caching keys and values. The core attention algorithms rely on matrix multiplications, softmax calculations, and specialized CUDA kernels.
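
As a usage sketch (an illustration assuming half-precision CUDA tensors in the (batch, seqlen, nheads, headdim) layout, which is what these kernels expect):

    import torch
    from flash_attn import flash_attn_func

    # Illustrative call: q, k, v are fp16/bf16 CUDA tensors of shape
    # (batch, seqlen, nheads, headdim).
    q = torch.randn(2, 1024, 16, 64, dtype=torch.float16, device="cuda")
    k = torch.randn(2, 1024, 16, 64, dtype=torch.float16, device="cuda")
    v = torch.randn(2, 1024, 16, 64, dtype=torch.float16, device="cuda")

    out = flash_attn_func(q, k, v, causal=True)  # same shape as q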

On top of this, the …/modules directory contains common transformer building blocks like multi-head attention and MLP layers. The …/layers directory implements specialized layers for vision and language models.

For low-level optimizations, …/ops leverages the Triton compiler for efficient CUDA kernels. The csrc directory contains C++/CUDA extensions implementing components like attention and losses.

Transformer models for vision and language tasks are implemented in …/models, with utilities in …/utils supporting distributed training and benchmarking.

The tests directory contains comprehensive unit testing of layers, losses, full models, and ops relative to reference implementations. Testing helps ensure optimizations do not affect model accuracy.

The training directory provides a full training loop implementation and utilities like data loading, metrics tracking, model definition, optimization strategies and distributed training. This supports end-to-end optimized model training.

In summary, Flash Attention provides building blocks, utilities and infrastructure focused on optimizing transformer attention mechanisms and models for efficiency across metrics like speed, memory usage, and power consumption.

Optimized Attention and Layers

References: flash_attn, flash_attn/layers, flash_attn/losses, flash_attn/modules, flash_attn/ops, flash_attn/utils

The flash_attn directory implements the core attention computations and optimizations used in the Flash Attention framework. It contains several important components for building efficient transformer models.

The …/ops directory contains highly optimized operations implemented using the Triton compiler. It provides kernels, modules, and utilities for common deep learning building blocks like attention, normalization, and linear layers.

The …/modules directory contains Python modules that implement common model building blocks like multi-head attention, embeddings, and MLPs.

The …/losses directory provides key loss functions for training, including the cross entropy loss.

The …/utils directory contains modules for tasks like benchmarking, distributed training, and loading pretrained models.

Attention Mechanisms

References: flash_attn, flash_attn/flash_attn_interface.py

The file …/flash_attn_interface.py implements the core attention computations through PyTorch autograd functions and classes. Wrapper functions apply these classes and kernels to handle different attention patterns. Variable-length attention is supported through cumulative sequence-length tensors that index into packed (unpadded) inputs.
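
A hedged sketch of the variable-length path, assuming an entry point that takes cumulative sequence-length tensors over packed inputs:

    import torch
    from flash_attn import flash_attn_varlen_func  # name as exposed by the package (assumed)

    # Two sequences of lengths 3 and 5 packed along a single "total tokens" dimension.
    nheads, headdim = 8, 64
    q = torch.randn(8, nheads, headdim, dtype=torch.float16, device="cuda")
    k, v = torch.randn_like(q), torch.randn_like(q)
    cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")  # cumulative lengths

    out = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens,
                                 max_seqlen_q=5, max_seqlen_k=5, causal=True)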

Model Building Blocks

References: flash_attn/modules

This section contains common building blocks used for constructing transformer models, including multi-head attention and multi-layer perceptrons (MLPs).

The file …/mha.py contains the multi-head attention classes.

MLP building blocks are defined in …/mlp.py, which contains classes for simple sequential MLPs.

The …/embedding.py file defines the embedding classes.
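
To make the composition concrete, here is a schematic pre-norm transformer block assembled from plain PyTorch layers plus flash_attn_func. It illustrates the pattern only and is not the repo's own attention or MLP classes:

    import torch
    import torch.nn as nn
    from flash_attn import flash_attn_func

    class Block(nn.Module):
        """Schematic pre-norm transformer block (illustration, not flash_attn.modules)."""
        def __init__(self, dim=512, heads=8):
            super().__init__()
            self.heads, self.head_dim = heads, dim // heads
            self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
            self.qkv = nn.Linear(dim, 3 * dim)
            self.proj = nn.Linear(dim, dim)
            self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

        def forward(self, x):  # x: (batch, seqlen, dim), fp16/bf16 on CUDA
            b, s, d = x.shape
            qkv = self.qkv(self.norm1(x)).view(b, s, 3, self.heads, self.head_dim)
            q, k, v = qkv.unbind(dim=2)
            x = x + self.proj(flash_attn_func(q, k, v, causal=True).reshape(b, s, d))
            return x + self.mlp(self.norm2(x))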

Layers

References: flash_attn/layers

The …/layers directory contains implementations of common layers used in both vision and language models. The two primary layers implemented are patch embedding and rotary position embeddings.

Patch embedding takes an input image and divides it into non-overlapping patches. The class in …/patch_embed.py handles this functionality: it validates the input, calculates properties like the number of patches, and projects each patch into an embedding space with a learned projection. It can optionally apply normalization and flatten the patches.

Rotary position embeddings encode positional information for self-attention. The …/rotary.py file contains functions and classes for applying rotary embeddings to the query/key/value tensors used in self-attention. It uses trigonometric encodings of positional indices to generate embedding vectors and rotates tensors based on their positions, so that attention scores depend only on relative offsets. Module classes apply the embeddings in place for efficiency.
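
A minimal, non-library sketch of the underlying math: each pair of channels is rotated by an angle proportional to the token position, so dot products between rotated queries and keys depend only on relative offsets:

    import torch

    def rotary_embed(x, base=10000.0):
        """Illustrative rotary embedding (non-interleaved): rotate channel pairs by
        position-dependent angles. x has shape (seqlen, nheads, headdim)."""
        seqlen, _, headdim = x.shape
        half = headdim // 2
        inv_freq = 1.0 / (base ** (torch.arange(half, device=x.device).float() / half))
        angles = torch.outer(torch.arange(seqlen, device=x.device).float(), inv_freq)
        cos, sin = angles.cos()[:, None, :], angles.sin()[:, None, :]
        x1, x2 = x[..., :half], x[..., half:]
        return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)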

Losses

References: flash_attn/losses

The …/losses directory contains loss functions used for training models in Flash Attention. The main loss implemented is defined in the …/cross_entropy.py file.

This file handles computing the loss between model outputs and targets, taking hyperparameters like ignore_index and the reduction type. The loss computation is encapsulated in this file so models can easily use it during training.
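
A hedged usage sketch, assuming the loss is exposed as a drop-in module under flash_attn.losses.cross_entropy (the exact class name here is an assumption):

    import torch
    from flash_attn.losses.cross_entropy import CrossEntropyLoss  # assumed class name

    loss_fn = CrossEntropyLoss(ignore_index=-100, reduction="mean")
    logits = torch.randn(8, 32000, dtype=torch.float16, device="cuda", requires_grad=True)
    labels = torch.randint(0, 32000, (8,), device="cuda")
    labels[0] = -100                      # this position is ignored in the loss
    loss = loss_fn(logits, labels)
    loss.backward()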

Low-Level Optimizations

References: flash_attn/ops, flash_attn/ops/triton

The Triton compiler is used to implement highly optimized CUDA kernels in the …/triton directory.

The file …/layer_norm.py provides layer normalization and RMS normalization with reference CPU implementations and parallel GPU kernels. It also includes a fused layer-norm-plus-linear module implemented with Triton.
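
For reference, RMS normalization rescales each row by its root-mean-square instead of subtracting a mean. A plain PyTorch version of what the kernel computes might look like the following (an illustrative reference, not the repo's code):

    import torch

    def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # Normalize by the root-mean-square over the last dimension, then apply a learned scale.
        rstd = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
        return (x.float() * rstd * weight.float()).to(x.dtype)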

Efficient linear and linear+activation operations are defined in …/linear.py. The kernels are optimized by partitioning inputs into blocks and tiling computations.

A fused dense-ReLU-dense module implemented in …/mlp.py uses these kernels, with an option to checkpoint activations to trade extra recomputation for lower memory usage.

Utilities

References: flash_attn/utils

This section covers utilities that support tasks like distributed training and benchmarking. The …/utils directory contains several modules for this purpose.

The …/benchmark.py module contains functions for benchmarking and profiling PyTorch models.

The …/generation.py module contains utilities for text generation with Transformer models.

The …/distributed.py module contains utilities for distributed training using PyTorch distributed.

The …/pretrained.py module loads pretrained model weights from local files or HuggingFace Hub.

Models

References: examples, flash_attn/models, tests/models

The …/models directory contains Python modules that define models for various natural language processing and computer vision tasks using the Flash Attention framework. Models are implemented for both language and vision tasks in order to evaluate the performance of the Flash Attention library across different domains. The modules also provide functionality for mapping between different model formats to enable loading pretrained weights from various frameworks.

The …/bert.py file implements BERT models using the HuggingFace library. It leverages padding functions in …/bert_padding.py for efficient self-attention.

The …/gpt.py file defines the GPT model classes.

The …/gpt_neox.py file contains functions for mapping between the GPT-NeoX model format from EleutherAI and the HuggingFace GPT2 format.

The …/vit.py file defines a Vision Transformer model based on the paper "An Image is Worth 16x16 Words". It contains utilities like initialization and stochastic depth.

Vision Models

References: flash_attn/models/vit.py

The file …/vit.py defines vision transformer models for computer vision tasks. It contains functionality for image classification using a Vision Transformer architecture.

The model takes an input image, splits it into patches, and embeds them to obtain patch embeddings. It then processes these embeddings with a series of transformer encoder blocks made up of residual connections and layer normalization as described in the paper. After processing, a classification token is pooled from the output and passed to a classification head to obtain predictions.

Some techniques implemented include residual connections between encoder layers and layer normalization before each layer. The model supports various options for patch size, embedding dimension, number of layers, and more as configurable hyperparameters. It can be trained end-to-end for vision classification tasks.
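
A schematic sketch of that flow in plain PyTorch (illustrative only; the repository's ViT uses its own optimized blocks):

    import torch
    import torch.nn as nn

    class TinyViT(nn.Module):
        """Schematic ViT: patch embedding -> transformer blocks -> CLS pooling -> head."""
        def __init__(self, img_size=224, patch_size=16, dim=192, depth=4, heads=3, num_classes=1000):
            super().__init__()
            num_patches = (img_size // patch_size) ** 2
            self.patch_embed = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)
            self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
            layer = nn.TransformerEncoderLayer(dim, heads, dim * 4, batch_first=True, norm_first=True)
            self.blocks = nn.TransformerEncoder(layer, depth)
            self.norm = nn.LayerNorm(dim)
            self.head = nn.Linear(dim, num_classes)

        def forward(self, x):                                    # x: (B, 3, H, W)
            x = self.patch_embed(x).flatten(2).transpose(1, 2)   # (B, num_patches, dim)
            cls = self.cls_token.expand(x.size(0), -1, -1)
            x = torch.cat([cls, x], dim=1) + self.pos_embed
            x = self.norm(self.blocks(x))
            return self.head(x[:, 0])                            # classify from the CLS token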

Language Models

References: flash_attn/models/bert.py, flash_attn/models/gpt.py, flash_attn/models/gpt_neox.py

This section defines models for natural language tasks using the Flash Attention framework. Key models implemented include BERT and GPT. The code also handles conversion between different model formats.

The …/bert.py file contains utilities for initializing weights, loading pretrained models, and remapping state dicts between frameworks. Padding is handled in …/bert_padding.py and blocks use configurable options like fused ops.

The …/gpt.py file defines the GPT models; transformer blocks are constructed by helper functions driven by the model configuration.

The …/gpt_neox.py file provides functions for mapping weights and configuration between formats, covering layers, embeddings, and more, so these models can be used with HuggingFace tooling.

Testing

References: tests/models

The …/models directory contains comprehensive unit tests for various transformer models implemented using the Flash Attention library. The tests validate that the Flash Attention model implementations match the original HuggingFace Transformers implementations in terms of state dictionaries, forward pass outputs, generation outputs and performance.

Each model has its own test file, such as …/test_gptj.py, whose test functions exercise the corresponding Flash Attention implementation.

Tests are defined for each model in files like …/test_gpt.py, which contains the tests for the GPT models. The test functions are parameterized over different configurations. Within each test, important steps include:

  • Loading implementations
  • Initializing models and remapping
  • Running inputs and comparing outputs
  • Generating sequences and comparing scores
  • Testing parallelization across GPUs
  • Checking optimized performance

The tests aim to validate optimizations retain functionality, accuracy and outputs. Comprehensive testing helps ensure optimizations introduce no regressions.

Benchmarking

References: benchmarks, flash_attn/utils

This section benchmarks and profiles the performance of Flash Attention compared to alternative attention implementations. It aims to evaluate both runtime and memory efficiency.

The benchmarks directory contains code to benchmark different attention types and mechanisms. It includes benchmarking attention computations relevant to the Flash Attention framework.

The …/benchmark_flash_attention.py file benchmarks key attention operations against reference PyTorch implementations. It defines test input tensors and runs the operations, timing the results. This allows direct comparison of performance.

The …/benchmark_causal.py file benchmarks causal attention types for language modeling. It implements several attention functions and runs benchmarks across hyperparameters like batch size and sequence length. This profiles efficiency for different problem sizes.

The …/benchmark.py module contains utilities for benchmarking models. Functions time the forward and backward passes of any user-defined function. This supports detailed profiling of models and layers.

Benchmarking Implementations

References: benchmarks, benchmarks/benchmark_flash_attention.py

The benchmarks directory contains code to benchmark different attention implementations against each other. It aims to compare the performance of approaches like PyTorch attention to Flash Attention.

The …/benchmark_flash_attention.py file benchmarks attention computations used in Flash Attention. This file saves results to files in benchmarks for analysis.

Benchmarking the core attention computations in Flash Attention against their PyTorch counterparts quantifies the performance improvements from its optimizations and validates that Flash Attention provides more efficient attention than the standard PyTorch implementation.

Utilities

References: flash_attn/utils/benchmark.py

The …/benchmark.py file contains various utility functions for benchmarking and profiling PyTorch models and layers.

The main functions include:

    def benchmark_forward(fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs):
    def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs):
    def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs):
    def benchmark_fwd_bwd(fn, *inputs, grad=None, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs):
    def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False, amp_dtype=torch.float16, cpu=False, verbose=True, **kwinputs):
    def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs):

These functions take the function under test followed by its inputs and time the forward and/or backward passes, allowing detailed benchmarking of models and layers. Automatic mixed precision is supported. Benchmarking results are returned or can be printed.
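
A hedged usage sketch (assuming extra keyword arguments are forwarded to the function under test):

    import torch
    from flash_attn import flash_attn_func
    from flash_attn.utils.benchmark import benchmark_forward, benchmark_memory

    q = torch.randn(2, 2048, 16, 64, dtype=torch.float16, device="cuda")
    k, v = torch.randn_like(q), torch.randn_like(q)

    # Time the forward pass and measure peak memory of the attention call.
    benchmark_forward(flash_attn_func, q, k, v, causal=True, repeats=30, desc="flash_attn")
    benchmark_memory(flash_attn_func, q, k, v, causal=True, desc="flash_attn")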

Distributed Training

References: flash_attn/utils/distributed.py, tests/modules

The …/distributed.py file contains utilities for distributed training using PyTorch distributed. It includes raw collective wrappers that lack autograd support but allow asynchronous execution.

The tests in …/modules validate distributed training functionality. The directory includes tests in …/test_block_parallel.py which partitions a block containing layers across devices. …/test_embedding_parallel.py partitions dimensions of an embedding layer across devices. …/test_mlp_parallel.py partitions layers of a multi-layer perceptron across devices.

Each test file initializes distributed training and model parallelism across different world sizes, generates input data and partitions it across processes, and initializes weights from a base model. They run and compare forward and backward passes on parallel and non-parallel models.

Model Performance

References: examples, tests/models

This section covers evaluating full Flash Attention models on downstream tasks. The examples and …/models directories contain examples and tests for this purpose.

Some key functionality includes:

  • The examples directory contains training and inference examples for models like BERT, GPT-2, and Vision Transformer. Examples are run on downstream tasks to evaluate performance.

  • Unit tests in …/models validate Flash Attention implementations of models like BERT, GPT-2, OPT match HuggingFace in terms of outputs and accuracy on tasks.

  • Tests like in …/test_gpt_generation_parallel.py evaluate generation performance of parallelized models against baselines.

  • Benchmarking tests in files like …/test_btlm.py and …/test_gpt.py compare optimized Flash Attention models to HuggingFace in FP16 and FP32, on metrics like generation speed and forward pass latency.

Distributed Training

References: flash_attn/utils, tests/modules, training, training/src/distributed

The core utilities provided for distributed training are handled in the …/distributed directory. This directory contains code for efficient gradient handling during distributed data parallel training.

The main component handles gradient compression and passing. Gradients are stored in a lower precision format, divided by the number of processes, and passed between processes. A callback then restores the original precision.

By handling the operations asynchronously, no additional synchronization steps are required. Storing gradients in a lower precision also reduces communication costs without losing information after passing between processes.

The …/utils directory contains additional utilities to support distributed training.

Distributed Data Parallel

References: flash_attn/utils/distributed.py, tests/modules

The …/distributed.py file contains utilities for distributed data parallel training across multiple GPUs. It includes a class that wraps raw distributed functions to support autograd during the forward and backward passes.
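
A minimal sketch of such a wrapper, assuming the standard pattern of pairing a collective in the forward pass with the matching collective on the gradient in the backward pass (illustrative, not the repo's exact class):

    import torch
    import torch.distributed as dist

    class AllReduceSum(torch.autograd.Function):
        """Autograd-aware all-reduce: sum across ranks forward, sum gradients backward."""
        @staticmethod
        def forward(ctx, x, group=None):
            ctx.group = group
            x = x.contiguous()
            dist.all_reduce(x, group=group)
            return x

        @staticmethod
        def backward(ctx, grad):
            grad = grad.contiguous()
            dist.all_reduce(grad, group=ctx.group)
            return grad, None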

The tests in …/modules validate distributed data parallel implementations by comparing outputs and gradients of parallel models to non-parallel baselines. Files partition components across devices using classes initialized from base models. They scatter inputs and partition gradients.

Gradient Compression

References: training/src/distributed

The …/distributed directory provides functionality for efficient gradient compression during distributed data parallel training. It implements gradient compression using a class defined in the …/ddp_comm_hooks.py file.

The class casts gradients to half precision format to compress the gradients. It then performs an asynchronous allreduce operation on the compressed gradients across the process group. A callback is registered to restore the gradients to their original precision after the allreduce completes. This avoids additional synchronization steps during gradient compression.

The key steps implemented are:

  • Gradients in a tensor are divided by the world size and cast for compression using fused operations.

  • An asynchronous allreduce is performed on the compressed gradients.

  • A callback registered with the future returned by allreduce restores the gradients to their original precision after communication.

By compressing gradients, communication costs are significantly reduced without losing information after the allreduce. This provides an efficient approach to distributed training at large scale.
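
A minimal sketch of such a hook, closely following PyTorch's built-in fp16 compression hook (the repo's version may differ in details):

    import torch
    import torch.distributed as dist

    def fp16_compress_hook(process_group, bucket):
        """Cast gradients to fp16, divide by world size, all-reduce asynchronously,
        then restore the original dtype in a callback (no extra synchronization)."""
        group = process_group if process_group is not None else dist.group.WORLD
        world_size = dist.get_world_size(group)
        compressed = bucket.buffer().to(torch.float16).div_(world_size)
        fut = dist.all_reduce(compressed, group=group, async_op=True).get_future()

        def decompress(fut):
            return fut.value()[0].to(bucket.buffer().dtype)

        return fut.then(decompress)

    # Usage: ddp_model.register_comm_hook(state=None, hook=fp16_compress_hook)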

Multi-Node Training

References: training/src/utils

The …/utils directory contains several utilities to support model parallel and multi-node distributed training.

The …/distributed.py module provides key functionality for multi-node training. It contains wrappers for initializing distributed backends and reducing values across workers. This allows tasks like averaging metrics and gradients across nodes.

The …/ddp_zero1.py and …/ddp_zero2.py modules implement different strategies for distributed data parallel training. Both classes override PyTorch Lightning's default distributed data parallel approach, enabling ZeRO-style optimizers that shard optimizer states across nodes. Checkpoint saving and loading is modified to separately save model and optimizer states indexed by rank.

The …/gpu_affinity.py module provides utilities for controlling CPU affinity of GPU devices. It can set the affinity of GPU processes to optimize for either data parallel or model parallel training strategies.

Low Level Optimizations

References: csrc, flash_attn/ops

The csrc directory implements key components like attention and losses using highly optimized CUDA/C++. It provides low-level optimizations that improve the performance and efficiency of deep learning models.

The …/flash_attn subdirectory contains the core C++ implementation of the Flash Attention mechanism. The …/flash_fwd_kernel.h and …/flash_bwd_kernel.h files define CUDA kernels for the forward and backward passes of attention. These kernels leverage optimizations like CUTLASS primitives for MMA and shared memory usage to efficiently parallelize attention across GPU threads and blocks.

The …/flash_fwd_launch_template.h and …/flash_bwd_launch_template.h files contain templates for launching the attention kernels. The templates select the optimal kernel based on hardware capabilities and properties of the input, such as sequence length and number of heads.

The …/kernel_traits.h and related files define structs containing traits for the attention kernels. The traits abstract away low-level details like data types, shared memory layouts, and thread block configurations. This allows implementing kernels agnostically to hardware specifics.

The …/layer_norm subdirectory provides a CUDA extension for fused dropout, residual, and layer normalization operations. The …/ln_api.cpp file implements the C++ API, handling input/output conversion and dispatching to kernels.

The …/xentropy subdirectory contains optimized CUDA implementations for calculating the cross-entropy loss function and its gradients in parallel. This allows efficient training of models.

The …/fused_dense_lib and …/fused_softmax subdirectories provide CUDA kernels for fused linear, softmax, and activation operations. Fusion improves performance over separate computation.

Overall, the csrc code implements low-level optimizations that accelerate key components, enabling efficient deep learning via optimized CUDA/C++ implementations.

Optimized Kernels

References: csrc/flash_attn, csrc/flash_attn/src

The Flash Attention kernels are highly optimized to efficiently perform the core attention computation on GPUs. Key optimizations include matrix multiply-accumulate (MMA) acceleration, use of shared memory, and prefetching of tensors.

The …/flash_fwd_kernel.h and …/flash_bwd_kernel.h files define templates for CUDA kernels that handle the forward and backward passes of a single query-key-value block. These kernels take tensors from global memory and load them into shared memory. They then perform the attention computation using MMA operations provided by CUTLASS, accumulating results into an output tensor in shared memory. Softmax and scaling are applied over multiple iterations. Finally, the results are written back to global memory.

The kernels support features like causal masking, dropout, splitting queries/keys/values across blocks, and appending new keys/values. Optimizations like reusing shared memory and prefetching tensors from global memory are implemented. Tensor layouts and copying between shared and global memory are abstracted using classes defined in files like …/kernel_traits.h and …/kernel_traits_sm90.h. These classes define element types, accumulation types, shared memory layouts, global memory copy structures, and shared memory sizes to optimize memory usage.

The …/flash_fwd_launch_template.h and …/flash_bwd_launch_template.h files contain templates for launching the different Flash Attention kernels on CUDA. They select the optimal kernel based on properties like head dimension size, hardware capabilities, and problem size/shape. The kernels are specialized via templates and boolean switches to efficiently handle different configurations with a minimal number of variants. Shared memory usage is tuned for each hardware architecture to maximize occupancy. Kernels are launched with grids and blocks optimized for the specific hardware and problem size.

Utilities defined in …/utils.h provide functions to support MMA operations, tensor copies between shared and global memory, type conversions, and tensor reductions.

Hardware Specialization

References: csrc/flash_attn, csrc/flash_attn/src

The …/src directory implements hardware-specific optimizations for the Flash Attention kernels through template metaprogramming. Key files like …/kernel_traits_sm90.h define traits for kernel parameters, data types, and memory layouts. These traits are templated on kernel dimensions and data types, allowing the implementations to specialize for different hardware configurations.

The …/kernel_traits_sm90.h file makes heavy use of CUTLASS and CUTE types to define memory layouts, copy operations, and optimizations like MMA and shared memory for NVIDIA GPUs. It provides types for element sizes, accumulation, indexing, and tensor layouts optimized for the target hardware. These abstract away low-level details and allow portability across devices.

Kernel launch templates in files like …/flash_fwd_launch_template.h select the optimal kernel variant and configuration based on compile-time traits.

The …/generate_kernels.py file programmatically generates all kernel variants by combining hyperparameters like SM architecture and data types. This allows targeting multiple hardware configurations with a minimum number of implementations.

Overall, the hardware-specific implementations are achieved through C++ templates and template metaprogramming to specialize kernels at compile-time for traits like data types and thread block sizes. This enables high performance by optimizing for target hardware characteristics.

Memory Optimizations

References: csrc/flash_attn, csrc/flash_attn/src

The …/kernel_traits_sm90.h file defines structs containing traits for CUDA kernels targeting SM90 architectures. These structs define element, accumulation, and index types optimized for techniques like MMA and shared memory usage. Types are also defined for copying between global and shared memory.

The structs allow abstracting kernel details so they can be implemented against the trait interfaces. Templating the structs on kernel parameters makes the traits configurable.

The …/kernel_traits_sm90.h file utilizes CUTLASS primitives to define optimized memory layouts and copy operations for techniques such as MMA, shared memory usage, and asynchronous global memory copies.

Triton Compiler

References: flash_attn/ops, flash_attn/ops/triton

The directory …/triton implements highly optimized operations for machine learning using the Triton compiler. It provides kernels, modules, and utilities that take advantage of Triton's capabilities for common deep learning building blocks.

Many core operations are implemented via Triton kernels that are optimized for parallelism and performance. For example, the file …/linear.py contains kernels for linear layers. These kernels partition computations into blocks that are processed in parallel by threads to maximize throughput. Optional bias addition and activation functions are fused into the kernels for lower latency. Autotuning finds configurations that optimize the kernels for different hardware.
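
To illustrate the blocking pattern with a generic Triton example (not the repo's linear kernel): each program instance handles one block of elements, loads it from global memory with a bounds mask, computes, and stores the result.

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(axis=0)                      # one program per block of elements
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements                      # guard the final partial block
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        tl.store(out_ptr + offsets, x + y, mask=mask)

    def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        n = x.numel()
        grid = (triton.cdiv(n, 1024),)
        _add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)
        return out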

Normalization layers like layer normalization and RMS normalization are implemented in …/layer_norm.py. This file provides reference CPU implementations along with parallel GPU kernels that leverage Triton primitives. The kernels process rows of the input tensor in parallel.

The file …/cross_entropy.py contains highly optimized Triton kernels for calculating the cross entropy loss function and its gradients. These kernels parallelize the loss calculation across threads and blocks. Techniques like splitting computation into blocks and all-gathering support large vocabularies.

Modules provide common deep learning primitives through high-performance Triton-compiled implementations.

C++/CUDA Extensions

References: csrc, csrc/fused_dense_lib, csrc/fused_softmax, csrc/layer_norm, csrc/rotary, csrc/xentropy

The csrc directory contains C++/CUDA source code implementing optimized GPU operations and extensions for deep learning models. Several key extensions are provided:

The …/fused_dense_lib directory contains kernels implementing fused linear and activation operations.

The …/fused_softmax directory implements fused and masked softmax kernels. The …/scaled_masked_softmax.h file defines kernels performing softmax within warps by partitioning inputs across blocks and warps, loading to shared memory, and computing in parallel.

The …/layer_norm directory provides a CUDA extension for fused dropout, residual, and layer normalization.

The …/rotary directory implements position embedding kernels. The kernels broadcast embeddings across threads to multiply inputs by cosine and sine values in parallel.

Comprehensive Testing

References: tests

The tests directory contains comprehensive unit tests that validate model accuracy, outputs, and performance relative to references like HuggingFace implementations. It tests models, layers, losses, and operations in depth through rigorous validation of key algorithms, components, and implementations.

Model testing covers models like BERT, GPT-2, and OPT and ensures they match HuggingFace in terms of state dictionaries, forward pass outputs, generation outputs, and performance. Tests validate accuracy, outputs, and efficiency gains through optimizations.

Layers are tested by modules in …/modules which partition components across GPUs using tensor model parallelism. They initialize parallelism, partition inputs, and compare forward and backward passes to non-parallel baselines. This validates correct partitioning and equivalent behavior.

Losses are tested by files in …/losses which implement functions to apply losses to random inputs and check outputs and gradients match PyTorch references under different configurations parameterized by dtype, smoothing, and more.

Operations are validated by modules containing extensive tests for classes implementing ops. For example, …/test_fused_dense.py tests classes against references by initializing weights, adding residuals, and comparing outputs and gradients with relaxed tolerances. …/test_layer_norm.py applies functions to random data and ensures outputs match references across batch sizes and dtypes.

Model Testing

References: tests/models

The …/models directory contains comprehensive unit tests for various transformer models implemented using the Flash Attention library. These tests validate that the Flash Attention model implementations match the original HuggingFace Transformers implementations in terms of state dictionaries, forward pass outputs, generation outputs and performance.

Tests are defined for each model in its own test file, such as …/test_bert.py and …/test_gpt.py. The test functions are parameterized over hyperparameters to cover different configurations. Models are initialized from pretrained checkpoints, and weights are copied between implementations using remapping functions. Models are run in evaluation mode on sample inputs, and the outputs are compared directly between implementations to validate equivalence. Generation is tested by comparing generation scores. Differences are printed and assertions check that thresholds are met. A sketch of this comparison pattern follows the list below.

Some key aspects tested across files include:

  • Loading models from checkpoints and remapping state dicts for use in Flash Attention
  • Running forward passes through the Flash Attention model and comparing outputs to HuggingFace
  • Generation using Flash Attention and comparing scores to HuggingFace
  • Testing generation with optimizations like CUDA graph caching
  • Checking state dictionary keys and shapes match between implementations
  • Comparing outputs when models are run in parallel across multiple GPUs
  • Parametrizing tests over different model sizes and settings
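
A hedged sketch of the comparison pattern (the model construction and helper names here are placeholders, not the repo's exact test code):

    import torch
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    def compare_logits(flash_model, model_name="gpt2", rtol=1e-2, atol=1e-2):
        """Run the same input through a HuggingFace reference and an optimized model
        (built and weight-remapped elsewhere), then compare outputs with relaxed tolerances."""
        ref = GPT2LMHeadModel.from_pretrained(model_name).to("cuda", torch.float16).eval()
        tok = GPT2Tokenizer.from_pretrained(model_name)
        input_ids = tok("Hello, my dog is cute", return_tensors="pt").input_ids.to("cuda")
        with torch.no_grad():
            ref_logits = ref(input_ids).logits
            flash_logits = flash_model(input_ids).logits
        torch.testing.assert_close(flash_logits, ref_logits, rtol=rtol, atol=atol)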

Layers Testing

References: tests/layers, tests/modules

The …/layers directory contains rigorous tests for the layers implemented in Flash Attention. These tests validate that layers like multi-head attention and MLPs produce the expected outputs and gradients.

The …/test_rotary.py file contains tests for rotary position embeddings. It tests two implementations: one that applies the embeddings to contiguous halves of the head dimension before attention, and one that interleaves the rotated channel pairs. Random inputs are generated, and the embeddings are applied by a class that builds sine and cosine weights from the rotary dimension and applies them in place, either via extraction and concatenation or via interleaving for the alternative implementation.

Tests in …/modules validate parallelizing key components across multiple GPUs. This includes partitioning an entire block containing attention and feedforward layers. Weights are copied from a non-parallel base model and partitioned across devices. Inputs are scattered and outputs/gradients compared with relaxed tolerances.

Losses Testing

References: tests/losses

The …/losses directory contains comprehensive tests for loss function implementations in Flash. It tests both serial and parallel versions of the cross entropy loss.

The …/test_cross_entropy.py file tests a custom loss implementation. It compares the implementation to PyTorch's loss class across different configurations like label smoothing and logit scaling. For each configuration, it checks that the forward passes and gradients match closely between the two implementations.

The file implements the core cross entropy loss computation in its forward pass. It takes the logits and target labels as input, computes the log probabilities, applies any scaling parameters, and computes the cross entropy loss, optionally adding a term based on the log-sum-exp (LSE) of the logits. The backward pass computes and returns gradients with respect to the logits tensor, supporting inplace computation.

The …/test_cross_entropy_parallel.py file tests a parallelized version of the cross entropy loss. It constructs a parallel loss function and compares its forward and backward passes to the non-parallel version across various configurations.

These tests thoroughly validate the key loss computation algorithms and ensure Flash's loss implementations are correct. They also test parallelization functionality.

Operations Testing

References: tests/ops, tests/ops/triton

The …/triton subdirectory contains tests for operations implemented with the Triton compiler. Tests are parameterized over configurations to thoroughly validate the implementations.

The file …/test_fused_dense.py contains tests for the fused dense operations. The fused classes are tested across different data types and configurations and checked for equivalence to a reference implementation. Gradient equivalence is also validated.

The …/test_dropout_layer_norm.py file tests the classes implementing fused dropout, residual addition, and layer normalization. Extensive unit tests validate the operations under various configurations by checking forward passes against references and asserting that gradients match.

Training Loop and Utilities

References: training, training/src

The core functionality provided by the code under the section Training Loop and Utilities includes implementing the full training loop, loading and preprocessing datasets, tracking metrics, defining models, handling optimization, and providing various utilities to support the training process.

The …/src directory contains the source code implementing these aspects.

The …/tasks subdirectory contains classes that define the actual training logic. These classes configure components and define the core training loop.

Data loading uses the data modules in …/datamodules, for example for language modeling tasks. These leverage libraries for text preprocessing.

Metrics are implemented as classes in …/metrics and track values such as accuracy, token counts, and perplexity.

Models are defined as classes in …/models. Common components live in …/modules.

Optimization uses …/optim for parameter grouping and custom optimizers.

Utilities for checkpointing, profiling, distributed training are in …/utils.

The training function, in …/train.py, initializes the above components from the config and sets up distributed training if needed. Evaluation leverages similar abstractions.

Callbacks defined in …/callbacks hook into training to monitor metrics, save checkpoints, and integrate functionality. These provide a clean way to modify training behavior without changing core logic.

Data Loading

References: training/src/datamodules, training/src/datamodules/datasets

The …/datamodules directory contains functionality for loading and preprocessing datasets.

Datasets are loaded and preprocessed using classes defined in relevant files. Language modeling datasets are handled by functions in …/detokenizer.py for detokenization and …/lm_dataset.py for efficiently loading tokenized sequences.
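
A minimal sketch of the pattern an LM dataset like this typically follows: serve fixed-length chunks of one long pre-tokenized array, with labels shifted by one token (illustrative, not the repo's exact class):

    import numpy as np
    import torch
    from torch.utils.data import Dataset

    class PackedLMDataset(Dataset):
        """Serve fixed-length (input, label) chunks from a 1D token array."""
        def __init__(self, tokens: np.ndarray, seq_len: int):
            self.tokens, self.seq_len = tokens, seq_len

        def __len__(self):
            return (len(self.tokens) - 1) // self.seq_len

        def __getitem__(self, idx):
            start = idx * self.seq_len
            chunk = torch.from_numpy(self.tokens[start:start + self.seq_len + 1].astype(np.int64))
            return chunk[:-1], chunk[1:]   # inputs and next-token labels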

ImageNet datasets are prepared and loaded from a specified path in …/imagenet.py. Standard transforms and preprocessing are applied.

Data is loaded in batches using PyTorch DataLoaders with optional distributed loading. Fault-tolerant sampling is supported by functionality in …/fault_tolerant_sampler.py.

The validation dataset for ImageNet is optionally cached using a dataset class to speed up validation. Mixup regularization is supported by …/timm_mixup.py.

Metrics Tracking

References: training/src/metrics

The …/metrics directory contains implementations of metrics that are tracked during model training. The main classes defined are:

  • The class in …/accuracy.py handles computing accuracy from both hard class labels and soft target probabilities. It checks for floating-point targets and, if found, takes the index of the maximum value along the last dimension to convert them to class indices before computing accuracy.

  • The class in …/num_tokens.py provides a running count of the total number of tokens seen during training. It stores the count as persistent metric state and overrides the reset behavior so the count is preserved between epochs rather than being reset to 0.

  • The class in …/perplexity.py calculates perplexity for language models. It inherits from the base metric class and overrides methods to accumulate the weighted loss total and token count, ignoring padding indices in the loss calculation. Perplexity is computed as the exponential of the accumulated loss divided by the count.

All classes leverage functionality from the base metric class, with important implementation choices like overriding update and reset behavior and using persistent metric state to store counts. These provide clean implementations of common training metrics.
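
A minimal sketch of the perplexity accumulation pattern, assuming a torchmetrics-style Metric base class (an assumption; the repo's class differs in details):

    import torch
    from torchmetrics import Metric

    class PerplexitySketch(Metric):
        """Accumulate summed loss and token count, then exponentiate the mean."""
        def __init__(self):
            super().__init__()
            self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
            self.add_state("count", default=torch.tensor(0.0), dist_reduce_fx="sum")

        def update(self, loss: torch.Tensor, num_tokens: int):
            self.total += loss.detach().float() * num_tokens
            self.count += num_tokens

        def compute(self):
            return torch.exp(self.total / self.count)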

Model Definition

References: training/src/models, training/src/models/modules

The …/models directory contains the source code for models used in sequence modeling tasks. The …/modules subdirectory defines common components for sequence modeling through Python modules.

The main file is …/seq_common.py, which contains functions that provide common building blocks for tasks like classification and language modeling.

Optimization

References: training/src/optim

The …/optim directory contains functionality for parameter grouping and customized optimizers to support model optimization in PyTorch. Parameter grouping separates the model's parameters into different groups for the optimizer based on criteria like weight decay. This is implemented by the function in …/param_grouping.py.

The function takes the model and optimizer configuration as input. It identifies parameters that should and should not have weight decay applied based on attributes like whether they are biases or in certain normalization modules. It separates parameters into three sets - those that will experience weight decay, those that won't, and those with special hyperparameters. It validates all parameters were assigned to a set before returning them grouped into optimizer parameter groups, with decayed groups having weight decay and others not. Any with special hyperparameters get their own group.
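
A minimal sketch of the grouping pattern described above: biases and other one-dimensional parameters (e.g. normalization weights) are exempt from weight decay, everything else decays (an illustrative simplification of the repo's function):

    import torch

    def group_parameters_for_weight_decay(model: torch.nn.Module, weight_decay: float = 0.1):
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            # Biases and 1-D tensors (LayerNorm/RMSNorm weights) get no weight decay.
            (no_decay if param.ndim <= 1 or name.endswith(".bias") else decay).append(param)
        return [
            {"params": decay, "weight_decay": weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ]

    # optimizer = torch.optim.AdamW(group_parameters_for_weight_decay(model), lr=3e-4)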

Customized optimizers are defined in Python files in …/optim. For example, one file subclasses an optimizer to make learning rate scheduling seamless: it maintains an internal counter of progress that is incremented automatically on each step, and this counter is used whenever no explicit step value is passed, so callers do not need to supply one on every call.

Utilities

References: training/src/utils

This section covers utilities used to support training tasks like checkpointing, profiling, and distributed training. The …/utils directory contains several important modules for these tasks.

The …/checkpoint.py module provides utilities for loading PyTorch checkpoints and modifying the state dictionaries.

The …/flops.py module contains functions for profiling models to calculate metrics like floating point operations (FLOPs) and activations counts. It can profile models using either DeepSpeed or FVCore profiling libraries if available.

The …/distributed.py module implements utilities for distributed training like initializing distributed backends and reducing values across workers. It provides wrappers for common patterns that work in both distributed and non-distributed settings.

The class defined in …/ema.py is used for maintaining an exponentially weighted moving average of model parameters. It initializes shadow parameter copies and provides methods for updating averages, copying averages back to the originals, and saving/loading class state.
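
A minimal sketch of the shadow-parameter pattern such an EMA class follows (illustrative, not the repo's implementation):

    import torch

    class EmaSketch:
        """Keep exponentially weighted copies of parameters and copy them back on demand."""
        def __init__(self, params, decay=0.999):
            self.decay = decay
            self.shadow = [p.detach().clone() for p in params]

        @torch.no_grad()
        def update(self, params):
            for s, p in zip(self.shadow, params):
                s.mul_(self.decay).add_(p.detach(), alpha=1 - self.decay)

        @torch.no_grad()
        def copy_to(self, params):
            for s, p in zip(self.shadow, params):
                p.copy_(s)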

Distributed Training

References: training/src/distributed

The …/distributed directory provides functionality for efficient distributed data parallel training of models across multiple GPUs or nodes. It implements gradient compression to reduce communication costs without losing gradient information.

The …/ddp_comm_hooks.py file handles gradient compression during distributed training. It divides gradients by the world size and performs an asynchronous allreduce operation. This compresses the gradients to reduce communication overhead. A callback restores the gradients after allreduce.

The key steps implemented are:

  • Gradients in a tensor are divided by world size for compression.

  • An asynchronous allreduce operation is performed on the compressed gradients across the process group.

  • A callback restores the gradients after completion.

This provides an efficient approach to distributed training without additional synchronization steps during gradient compression.

Callbacks

References: training/src/callbacks

The …/callbacks directory contains Python files that define callback classes for PyTorch Lightning. These callbacks hook into the training loop at specific points to monitor metrics, save checkpoints, and log other useful information.

The main callback classes defined monitor parameters, log loss scales for mixed precision training, profile FLOPs, integrate Weights & Biases functionality, implement exponential moving averages, and more.

The callback defined in …/params_log.py logs counts of the total, trainable, and non-trainable parameters in the model to the trainer's logger during training. It iterates over the model's parameters and sums their sizes based on configurable flags.

The callback in …/loss_scale_monitor.py monitors the loss scale for automatic mixed precision training. It overrides a PyTorch Lightning method to collect loss scale statistics from different sources depending on the trainer strategy and precision plugin. These statistics are logged to all loggers associated with the trainer.

The callback in …/ema.py implements exponential moving averages of the model parameters during training. It initializes an averaging object from …/ema.py with the model's parameters and periodically updates the averages. The callback handles saving/restoring the original model parameters and the averaging object's state during validation and checkpointing.