flash-attention
Auto-generated from Dao-AILab/flash-attention by Mutable.ai Auto Wiki
flash-attention | |
---|---|
GitHub Repository | |
Developer | Dao-AILab |
Written in | Python |
Stars | 10k |
Watchers | 103 |
Created | 05/19/2022 |
Last updated | 04/03/2024 |
License | BSD 3-Clause "New" or "Revised" |
Repository | Dao-AILab/flash-attention |
Auto Wiki | |
Revision | |
Software Version | 0.0.8Basic |
Generated from | Commit 23e8fa |
Generated at | 04/03/2024 |
The flash-attention repository provides a highly optimized and memory-efficient implementation of the attention mechanism, a core component of transformer-based models. The library is designed to leverage the capabilities of modern GPU architectures to achieve significant performance improvements over standard attention implementations.
The most important parts of the repository are the core implementation of the Flash Attention algorithm, the optimized transformer-based model implementations, and the collection of highly efficient operations used throughout the library.
The flash-attention/csrc/flash_attn directory contains the core implementation of the Flash Attention algorithm, including the CUDA kernel functions for the forward and backward passes. These kernels handle features such as dropout, causality, and local attention, and are parameterized to work efficiently across different data types, head dimensions, and CUDA compute capabilities. The flash.h, flash_fwd_kernel.h, and flash_bwd_kernel.h files define the core functionality of the Flash Attention algorithm.
The flash-attention/flash_attn/models directory provides implementations of various transformer-based language models, such as BERT and GPT, which leverage the optimized attention and other operations from the flash-attention library. These model implementations demonstrate how the core components can be integrated into larger machine learning pipelines.
The flash-attention/flash_attn/ops and flash-attention/csrc directories contain a collection of highly optimized and flexible operations, including Triton-accelerated operations, fused dense and MLP layers, and layer normalization and RMS normalization. These operations are crucial for achieving high performance in transformer-based models and are designed to be easily integrated into various machine learning frameworks.
The key algorithms and technologies used in the flash-attention repository include:
- Flash Attention: A highly optimized attention mechanism that leverages the capabilities of modern GPU architectures to achieve significant performance improvements over standard attention implementations.
- Triton: A domain-specific language and compiler for writing efficient GPU kernels, used extensively throughout the library to implement various neural network operations.
- Tensor Parallelism: The library uses tensor parallelism to distribute the computation of certain operations, such as linear layers and MLPs, across multiple GPUs.
The main design choices of the flash-attention repository include:
- Modular and Extensible Architecture: The codebase is organized into several directories, each focusing on a specific aspect of the library's functionality, making it easier to maintain and extend the code.
- Extensive Use of Template Metaprogramming: The CUDA kernel implementations make extensive use of template metaprogramming techniques to generate specialized kernels for different data types, head dimensions, and CUDA compute capabilities, improving performance and flexibility.
- Optimization-Focused: The library prioritizes performance and efficiency, with a strong focus on leveraging the capabilities of modern GPU architectures through the use of techniques like Triton-accelerated kernels, fused operations, and tensor parallelism.
Core Attention Mechanism
References: csrc/flash_attn, csrc/flash_attn/src
The core functionality of the Flash Attention algorithm is implemented in the …/ directory and the …/ directory. This includes the forward and backward pass kernels, attention masking, and rotary position encoding.
Attention Kernel Implementations
References: flash-attention
The core CUDA kernel implementations for the forward and backward passes of the Flash Attention algorithm are located in the …/flash_attn directory. These kernels provide highly optimized implementations of the attention mechanism, with support for various features like dropout, causality, and local attention.
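The central idea behind these kernels, processing keys and values one block at a time while keeping a running softmax maximum and normalizer, can be illustrated in plain Python. This is a minimal sketch for a single query row, not the repository's CUDA code: the function names and the `block_size` parameter are hypothetical, and the usual 1/sqrt(d) scaling, dropout, and masking are omitted for brevity.

```python
import math

# Hypothetical illustration only; the real kernels are CUDA templates.

def naive_attention_row(q, keys, values):
    """Reference: softmax over all scores, then a weighted sum of values."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(dim)]

def tiled_attention_row(q, keys, values, block_size=2):
    """Same result, but keys/values are visited one block at a time."""
    dim = len(values[0])
    running_max = float("-inf")  # largest score seen so far
    running_sum = 0.0            # softmax normalizer under running_max
    acc = [0.0] * dim            # unnormalized weighted sum of values
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [sum(a * b for a, b in zip(q, k)) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale everything accumulated under the old maximum.
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * x for a, x in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]
```

Because only the running maximum, the running sum, and the accumulator are carried between blocks, the full row of attention scores never has to be stored at once, which is where the memory savings come from.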
Attention Masking and Encoding
References: flash-attention
The …/flash_attn directory contains the core implementation of the Flash Attention algorithm, which includes functionality for applying various types of attention masks and performing rotary position encoding on input tensors.
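As a rough illustration of what causal and local (sliding-window) masks mean, the sketch below builds a boolean visibility matrix. The helper names and the `(left, right)` window convention are assumptions made for this example; the kernels apply masking on the fly rather than materializing a matrix like this.

```python
# Hypothetical helpers for illustration; not part of the library.

def is_visible(query_pos, key_pos, causal=False, window=None):
    """Return True if the query at query_pos may attend to key_pos."""
    if causal and key_pos > query_pos:
        return False  # causal: no attending to future positions
    if window is not None:
        left, right = window  # assumed convention: keys allowed on each side
        if key_pos < query_pos - left or key_pos > query_pos + right:
            return False
    return True

def build_mask(seq_len, causal=False, window=None):
    """Materialize the full seq_len x seq_len visibility matrix."""
    return [[is_visible(i, j, causal, window) for j in range(seq_len)]
            for i in range(seq_len)]
```

With `causal=True`, row `i` of the mask is true only for columns `0..i`; with `window=(1, 0)`, each query sees itself and the one key before it.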
Attention Mechanism Configuration
References: flash-attention
The core of the Flash Attention algorithm's efficient GPU implementation is managed through a set of helper structs and functions defined in the …/flash_attn directory. These components handle the complex memory layouts, tiling, and parameterization required to achieve high performance on modern GPU architectures.
Auxiliary Components
References: flash-attention
The …/ops directory contains a collection of highly optimized and flexible operations that are crucial components of the Flash Attention library. The main functionality in this directory includes Triton-accelerated operations, fused dense and MLP layers, and layer normalization and RMS normalization.
Transformer-based Models
References: flash_attn/models
The Flash Attention repository provides implementations of various transformer-based language models, including BERT and GPT, as well as specialized variants.
BERT Model
References: flash_attn/models/bert.py
The BertEncoder class is responsible for passing the input embeddings through the BERT transformer blocks. It uses the create_block() function to create the individual transformer blocks, each of which consists of a multi-head attention (MHA) module and a multi-layer perceptron (MLP) module. When key_padding_mask is provided, the forward() method of BertEncoder unpads the input and processes the remaining tokens in a batched manner, which avoids spending computation on padding. If subset_mask is provided, the encoder computes the last layer output only for the subset of tokens specified by the mask, which is useful for the BERT pre-training task.
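The unpadding step can be pictured as flattening the batch into one packed list of real tokens plus a record of where each sequence starts and ends. The sketch below shows this idea on plain Python lists; `unpad`, `repad`, and `cu_seqlens` are names chosen for this illustration, and the example assumes padding sits at the end of each sequence.

```python
# Hypothetical list-based illustration of unpadding; not the library's code.

def unpad(batch, key_padding_mask):
    """Flatten a padded batch, keeping positions where the mask is True."""
    tokens = []
    cu_seqlens = [0]  # cumulative sequence lengths, one boundary per sequence
    for seq, mask in zip(batch, key_padding_mask):
        kept = [tok for tok, keep in zip(seq, mask) if keep]
        tokens.extend(kept)
        cu_seqlens.append(cu_seqlens[-1] + len(kept))
    return tokens, cu_seqlens

def repad(tokens, cu_seqlens, max_len, pad_value=0):
    """Inverse of unpad, assuming padding belongs at the end of each row."""
    batch = []
    for start, end in zip(cu_seqlens, cu_seqlens[1:]):
        seq = tokens[start:end]
        batch.append(seq + [pad_value] * (max_len - len(seq)))
    return batch
```

Any per-token computation can then run over the packed list, whose length is the number of real tokens rather than batch size times maximum length.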
GPT Model
References: flash_attn/models/gpt.py
The GPTModel class in the …/gpt.py file provides the core implementation of the GPT (Generative Pre-trained Transformer) language model. This model is a key component of the Flash Attention library, which focuses on optimizing attention mechanisms for modern GPU architectures.
Other Transformer-based Models
The Flash Attention repository includes implementations and utility functions for several other transformer-based language models, such as GPT-NeoX, GPT-J, and Vision Transformer (ViT).
Optimized Operations
References: flash_attn/ops, csrc/fused_dense_lib, csrc/fused_softmax, csrc/layer_norm, csrc/rotary
The Flash Attention repository provides a collection of highly optimized and flexible operations, including layer normalization, linear and MLP layers, and their tensor-parallel implementations.
Triton-Accelerated Operations
References: flash_attn/ops/triton
The Triton-accelerated operations used in the Flash Attention library include cross-entropy loss, activation functions, layer normalization, linear layers, and multi-layer perceptrons.
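To make one of these operations concrete, here is a minimal, numerically stable cross-entropy for a single example, written in plain Python rather than Triton. The function name is hypothetical and the code only illustrates the underlying formula (log-sum-exp of the logits minus the target logit); it says nothing about how the Triton kernel is organized.

```python
import math

# Hypothetical reference formula; the library's version is a Triton kernel.

def cross_entropy(logits, target):
    """Negative log-probability of `target` under softmax(logits)."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]
```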
Fused Dense and MLP Layers
References: flash_attn/ops/fused_dense.py
The flash-attention/flash_attn/ops/fused_dense.py file contains highly optimized implementations of linear layers and multi-layer perceptrons (MLPs) for CUDA devices. These implementations focus on performance and support tensor parallelism with sequence parallelism.
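Tensor parallelism for a linear layer can be sketched by splitting the weight matrix by output feature, letting each rank compute its slice, and concatenating the results. The example below simulates the ranks with a Python list; the function names and the `world_size` parameter are assumptions for illustration, and the number of output features is assumed to be divisible by `world_size`.

```python
# Hypothetical single-process simulation of a column-parallel linear layer.

def linear(x, weight):
    """y = weight @ x for one input vector; one weight row per output."""
    return [sum(a * b for a, b in zip(x, row)) for row in weight]

def column_parallel_linear(x, weight, world_size):
    """Split output features across simulated ranks, then gather."""
    rows_per_rank = len(weight) // world_size
    shards = [weight[r * rows_per_rank:(r + 1) * rows_per_rank]
              for r in range(world_size)]
    partial_outputs = [linear(x, shard) for shard in shards]  # one per rank
    # Concatenation stands in for the all-gather a real setup would perform.
    return [y for part in partial_outputs for y in part]
```

Each simulated rank holds only its own rows of the weight matrix, so the parameter memory per rank shrinks by a factor of `world_size`.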
Layer Normalization and RMS Normalization
References: flash_attn/ops/layer_norm.py, flash_attn/ops/rms_norm.py
The layer_norm.py and rms_norm.py files in the …/ directory contain efficient implementations of layer normalization and root-mean-square (RMS) normalization, with support for dropout, residual connections, and parallel processing.
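The difference between the two normalizations is easiest to see in a reference formula. The sketch below omits the learnable weight and bias as well as the fused dropout and residual handling mentioned above; the function names and the default `eps` are choices made for this example.

```python
import math

# Hypothetical reference formulas without learnable parameters.

def layer_norm(x, eps=1e-5):
    """Subtract the mean, then divide by the standard deviation."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def rms_norm(x, eps=1e-5):
    """Divide by the root mean square; no mean subtraction."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]
```

RMS normalization skips the mean computation, which is one reason it is slightly cheaper than layer normalization.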
Rotary Positional Embeddings
References: flash_attn/ops/triton/rotary.py
The flash-attention/flash_attn/ops/triton/rotary.py file contains a Triton kernel implementation and a Python function for applying rotary positional embeddings to input tensors. Rotary positional embeddings are a type of positional encoding that transformer-based models can use to incorporate positional information without adding learned parameters.
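A rotary embedding rotates pairs of features by an angle that grows with the token position. The following sketch applies it to one vector in plain Python. It assumes an even feature dimension, pairs adjacent elements (an interleaved layout, which is only one of the layouts used in practice), and uses the common base of 10000; the function name and these defaults are assumptions for illustration.

```python
import math

# Hypothetical single-vector illustration; the library uses a Triton kernel.

def apply_rotary(x, position, base=10000.0):
    """Rotate pairs (x[0], x[1]), (x[2], x[3]), ... by position-based angles."""
    dim = len(x)  # assumed to be even
    out = []
    for i in range(0, dim, 2):
        theta = position * base ** (-i / dim)  # lower pairs rotate faster
        c, s = math.cos(theta), math.sin(theta)
        out.extend([x[i] * c - x[i + 1] * s, x[i] * s + x[i + 1] * c])
    return out
```

Because each pair is only rotated, the dot product between a rotated query and a rotated key depends on the difference of their positions rather than on the absolute positions.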
Utility Components
References: flash_attn/utils, training/src/utils
The utility components are functions and classes that support the Flash Attention library in areas such as benchmarking, distributed training, text generation, and pre-trained model loading.
Benchmarking Utilities
References: flash_attn/utils/benchmark.py
The benchmark.py file in the …/ directory provides a set of utilities for benchmarking the performance of PyTorch functions, including forward and backward pass timing, automatic mixed precision (AMP) support, and memory usage profiling.
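The basic shape of such a timing helper can be sketched with the standard library alone. The `benchmark` function below is a hypothetical stand-in, not the API of benchmark.py: it measures wall-clock time on the CPU and does not synchronize a GPU, which a real measurement of CUDA code would need to do before reading the clock.

```python
import time

# Hypothetical CPU-only timing helper; not the benchmark.py API.

def benchmark(fn, *args, repeats=10, **kwargs):
    """Return the mean wall-clock seconds per call of fn(*args, **kwargs)."""
    fn(*args, **kwargs)  # warm-up call, not timed
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args, **kwargs)
    return (time.perf_counter() - start) / repeats
```

A call such as `benchmark(sorted, list(range(1000)), repeats=100)` returns the average time of one `sorted` call in seconds.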
Distributed Training Utilities
References: flash_attn/utils/distributed.py
The distributed.py file in the …/ directory provides a set of utility functions and classes for distributed training using PyTorch's distributed package. The main functionalities include:
Text Generation Utilities
References: flash_attn/utils/generation.py
The flash-attention/flash_attn/utils/generation.py file provides a set of utilities for efficient text generation using transformer models. The key functionality includes:
Pre-trained Model Loading Utilities
References: flash_attn/utils/pretrained.py
The state_dict_from_pretrained() function in the …/pretrained.py file is the primary entry point for loading pre-trained model weights. It loads weights from various sources, including local files and the Hugging Face Hub.
Checkpoint Management
References: training/src/utils/checkpoint.py
The checkpoint.py file in the …/ directory provides utility functions for loading and manipulating model checkpoints. The main functionality includes:
Distributed Training Support
The DDPStrategyZero1 class in …/ddp_zero1.py is an extension of the DDPStrategy class from PyTorch Lightning. It is responsible for handling the state management of the ZeroRedundancyOptimizer in a distributed training setup.
Distributed Training Utilities
References: training/src/utils/distributed.py
The …/distributed.py file provides utilities for initializing and managing the distributed backend, synchronizing workers, and performing distributed operations.
Exponential Moving Average
References: training/src/utils/ema.py
The ExponentialMovingAverage class, located in …/ema.py, provides utilities for maintaining and updating the exponential moving average (EMA) of model parameters. This is useful for model validation and inference, because the EMA parameters can be used to evaluate the model without affecting the parameters being optimized.
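The update rule behind an EMA of parameters is short enough to show directly. The sketch below works on flat lists of floats; `ema_update`, `shadow`, and the default `decay` are names and values chosen for this illustration and are not taken from the ExponentialMovingAverage class.

```python
# Hypothetical EMA update on flat lists of floats.

def ema_update(shadow, params, decay=0.999):
    """Move each shadow value a fraction (1 - decay) toward the parameter."""
    return [decay * s + (1.0 - decay) * p for s, p in zip(shadow, params)]
```

The training loop keeps optimizing `params` as usual, calls `ema_update` after each optimizer step, and evaluates with the smoothed `shadow` values.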
Model Profiling
References: training/src/utils/flops.py
The flops.py file in the …/ directory provides utilities for profiling the computational complexity of PyTorch models. It offers two main approaches for model profiling:
GPU Affinity Management
References: training/src/utils/gpu_affinity.py
The gpu_affinity.py file provides utilities for setting the CPU affinity of the process to the CPU cores associated with the specified GPU device(s). This is important for optimizing performance and ensuring efficient utilization of system resources.
Logging and Configuration Utilities
References: training/src/utils/utils.py
The …/utils.py file provides utility functions and classes for logging, handling optional configurations, printing project configurations, and ensuring proper cleanup, particularly for the Weights & Biases (WandB) logger.
Sequence Modeling Pipeline
References: training/src
The core functionality for the sequence modeling pipeline in the Flash Attention project is provided by the …/src directory. This directory contains the main SequenceModel class, which handles the overall sequence modeling workflow, and the SequenceLMModel class, which is a specialized version for language modeling tasks. The directory also includes various utility modules and classes for data preprocessing, model optimization, distributed training, and model evaluation.
Data Preprocessing
References: training/src/datamodules
The …/datamodules directory contains components for working with various datasets, including language modeling datasets and the ImageNet dataset, along with support for data augmentation techniques like Mixup.
Model Optimization
References: training/src/optim
The …/optim directory contains the key components responsible for model optimization in the Flash Attention library.
Distributed Training
References: training/src/distributed, training/src/utils/ddp_zero1.py, training/src/utils/ddp_zero2.py, training/src/utils/distributed.py
The key components in the distributed training utilities are:
Model Evaluation
References: training/src/metrics, training/src/eval.py, training/src/train.py
The …/metrics directory provides a set of custom PyTorch metric classes that extend the functionality of the torchmetrics library to handle specific use cases, such as Mixup data augmentation and language model evaluation.
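Perplexity is a standard metric for language model evaluation and is assumed here as an example of what such a metric computes; the source does not say which metrics the directory defines. The sketch below takes per-token negative log-likelihoods (natural log) and returns the exponential of their mean. The function name is hypothetical.

```python
import math

# Hypothetical standalone metric; not one of the repository's metric classes.

def perplexity(token_nlls):
    """exp of the mean per-token negative log-likelihood (natural log)."""
    return math.exp(sum(token_nlls) / len(token_nlls))
```

A model that assigns probability 1/4 to every correct token has a perplexity of 4, which matches the intuition of choosing uniformly among four options.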