Auto Wiki by Mutable.ai

whisper

Auto-generated from openai/whisper by Mutable.ai Auto Wiki

whisper
GitHub Repository
Developer: openai
Written in: Python
Stars: 53k
Watchers: 460
Created: 2022-09-16
Last updated: 2023-12-31
License: MIT
Repository: openai/whisper
Auto Wiki
Generated at: 2023-12-31
Generated from: Commit ba3f3c
Version: 0.0.4

Whisper is an open-source speech recognition library developed by OpenAI. It implements end-to-end deep learning models for transcribing audio to text and translating speech into English text.

The main functionality is located in the whisper directory, which contains modules for audio processing, model architectures, decoding, and tokenization. The pretrained models allow high-quality speech recognition and translation across many languages.

Key algorithms used include convolutional neural networks in the encoder to process audio spectrograms, Transformer models with self-attention in the decoder to generate text, and beam search for decoding audio to text efficiently. The models are trained on 680,000 hours of multilingual and multitask supervised speech data.

Some notable design choices are the use of a shared multilingual model architecture that can handle many languages, relying on attention mechanisms instead of RNNs, and joint training on multiple speech tasks like transcription, translation, and language identification. The model can be used via Python for inference or via the command line interface.

The load_model() function in …/__init__.py is the main entry point, loading a pretrained Whisper model by name for use in Python. The loaded model's transcribe() method handles audio preprocessing, decoding, and formatting results.
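
A minimal usage sketch of this entry point (the model name "base" and the audio path are placeholders):

```python
import whisper

model = whisper.load_model("base")       # load a pretrained checkpoint by name
result = model.transcribe("audio.mp3")   # runs preprocessing, decoding, and formatting
print(result["text"])                    # the transcribed text
```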

Unit tests in tests validate core functionality like audio loading, tokenization, text normalization, and end-to-end transcription without exposing internal implementation details.

Key preprocessing steps occur in …/audio.py, including loading audio with load_audio() and generating mel spectrograms with log_mel_spectrogram().

The core model architecture is defined in …/model.py, including the AudioEncoder, TextDecoder, and full Whisper model. The encoder processes spectrograms while the decoder generates text using cross-attention.

Decoding from audio to text is handled in …/decoding.py using classes like DecodingTask and beam search in BeamSearchDecoder.

Text normalization utilities live in …/normalizers; they clean and standardize text, primarily so that transcriptions can be compared consistently during evaluation.

Speech Recognition

References: whisper, tests

The core functionality for transcribing audio into text is handled by the DecodingTask class defined in …/decoding.py. DecodingTask manages the full decoding process by initializing important components like the Tokenizer, Inference implementation, and TokenDecoder based on the provided DecodingOptions.

The DecodingTask class handles preprocessing steps like language identification or applying a prompt/prefix. Its _main_loop() method runs the core autoregressive decoding, applying the TokenDecoder and any LogitFilters at each step.

The Inference interface defines how models are forwarded and handles installing caching hooks. PyTorchInference implements caching of keys/values from the decoder blocks.

Classes like GreedyDecoder and BeamSearchDecoder implement different decoding algorithms by extending TokenDecoder. They select the next token during decoding.

LogitFilter classes apply rules to the logits like suppressing blank tokens or timestamps that don't follow patterns.

The result is formatted into a DecodingResult dataclass and returned. Type annotations and dataclasses are used throughout to keep the interfaces explicit. This implements a flexible but optimized decoding pipeline.
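
For a lower-level view of this pipeline, the decoding entry points can be driven directly, along the lines of the repository README (the file path and the fp16 flag are illustrative):

```python
import whisper

model = whisper.load_model("base")

# prepare a 30-second log-mel spectrogram on the model's device
audio = whisper.pad_or_trim(whisper.load_audio("audio.mp3"))
mel = whisper.log_mel_spectrogram(audio).to(model.device)

# language identification, as performed by DecodingTask when no language is given
_, probs = model.detect_language(mel)
print(f"Detected language: {max(probs, key=probs.get)}")

# run the decoding pipeline described above
options = whisper.DecodingOptions(fp16=False)
result = whisper.decode(model, mel, options)
print(result.text)
```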

The Audio Processing, Model Architecture, and Tokenization sections will discuss how audio is preprocessed, models are defined, and text is tokenized, which are important preprocessing steps for decoding.

The Testing section validates the decoding functionality works as intended.

Audio Processing

References: whisper/audio.py

The core audio preprocessing functionality in Whisper is handled by functions in the …/audio.py file. This file contains utilities for loading raw audio, preprocessing it, and converting it to mel spectrograms which serve as the model input.

The load_audio function provides a consistent way to open and preprocess raw audio files across different platforms. It handles downmixing audio to mono and resampling all files to 16kHz sample rate. This ensures a standardized audio representation for further processing.

Audio samples may be of varying lengths, so the pad_or_trim function pads or trims each sample to N_SAMPLES samples (30 seconds at 16 kHz). This produces audio of a fixed size expected by the encoder models. It supports NumPy arrays and PyTorch tensors to accommodate different use cases.

The mel_filters function loads precomputed Mel filterbank matrices bundled with the package rather than relying on an external library like librosa at runtime. This improves efficiency by allowing filtering to be done via matrix multiplication. It caches the filters to avoid reloading them.

A key step is the log_mel_spectrogram function, which takes a preprocessed audio tensor and computes the short-time Fourier transform. It then projects this to a Mel spectrogram using the precomputed filters. The output is a log Mel spectrogram tensor, with optional padding applied to standardize the shape. This processed tensor serves as the model input representation of the audio signal.
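
A short sketch of this preprocessing chain (the file name is a placeholder; shapes follow the defaults described above):

```python
import whisper

audio = whisper.load_audio("speech.wav")   # mono float32 waveform resampled to 16 kHz
audio = whisper.pad_or_trim(audio)         # pad or trim to 30 seconds of samples
mel = whisper.log_mel_spectrogram(audio)   # tensor of shape (n_mels, n_frames)
```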

Model Architecture

References: whisper/model.py

The Whisper class defines the overall architecture of the speech recognition model. It combines the AudioEncoder and TextDecoder through cross-attention to map input audio to predicted text.

The AudioEncoder encodes input spectrograms into a sequence of audio embeddings. It consists of a stack of convolutional layers followed by residual attention blocks implemented in the ResidualAttentionBlock class. The convolutions project the mel channels to the model dimension and downsample the time axis, while the attention blocks then model long-range dependencies in the time dimension.

The TextDecoder predicts output text tokens based on the encoded audio representation. It contains a similar stack of ResidualAttentionBlocks, but these blocks perform cross-attention between the decoder queries and encoded audio keys/values. At each timestep, the decoder predicts the next token based on its own self-attention and the context from the audio encoder.

Both the encoder and decoder make use of the MultiHeadAttention module for their self-attention and cross-attention. This module implements scaled dot-product attention with multiple heads to allow the model to jointly attend to information from different representation subspaces. For efficiency, it caches computed attention keys and values for reuse between positions.

The full Whisper model ties together the encoder and decoder. It handles model dimensions and embeddings, passes inputs through the encoder, and passes the encoded audio to the decoder for text prediction. This end-to-end speech recognition pipeline allows the model to directly map speech inputs to text predictions.
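
A sketch of that end-to-end flow using the public modules (the token id below is assumed to be the multilingual start-of-transcript token; shapes are indicative):

```python
import torch
import whisper

model = whisper.load_model("tiny")

mel = torch.zeros(1, model.dims.n_mels, 3000)    # a dummy 30-second log-mel spectrogram
tokens = torch.tensor([[50258]])                 # assumed <|startoftranscript|> id

audio_features = model.encoder(mel)              # (1, n_audio_ctx, n_audio_state)
logits = model.decoder(tokens, audio_features)   # (1, n_tokens, n_vocab)
print(audio_features.shape, logits.shape)
```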

Decoding

References: whisper/decoding.py

The DecodingTask class manages the full decoding process. It initializes important components like the Tokenizer, Inference implementation, and TokenDecoder based on DecodingOptions.

DecodingTask also handles preprocessing like language identification and applying a prompt/prefix, while its _main_loop() method runs the core autoregressive decoding, applying the TokenDecoder and LogitFilters at each step.

The Inference interface defines model forwarding and handles installing caching hooks. The PyTorchInference implementation caches keys/values from the model's decoder blocks.

The TokenDecoder interface defines how the next token is selected. Classes like GreedyDecoder and BeamSearchDecoder implement different decoding algorithms. GreedyDecoder selects the token with highest logit at each step while BeamSearchDecoder maintains a beam of candidate sequences.

LogitFilter classes apply rules to the logits like suppressing blank tokens or timestamps that don't follow patterns. This helps constrain the search space during decoding.

The result is formatted into a DecodingResult and returned. Type annotations are used throughout to ensure correctness, and optimizations like FP16 inference and key/value caching improve speed.

The code implements a flexible but optimized decoding pipeline to transcribe audio into text using trained models. It handles tasks like language identification, audio encoding, constrained decoding search, and result formatting in a modular way.

Tokenization

References: whisper/tokenizer.py

The Tokenizer class handles converting raw text into discrete tokens for model input and output. It serves as the main entry point for tokenization functionality. Tokenizer initializes with special token IDs and provides methods like encode(), decode(), and decode_with_timestamps() which delegate the core work to an underlying Encoding instance.

Cached properties like language_token and non_speech_tokens extract commonly needed values from the Encoding for optimized performance of common operations. Methods such as split_to_word_tokens() preprocess text into word tokens according to language-specific conventions.

The Encoding class represents the vocabulary and tokenization rules for a particular model configuration. Its constructor builds a mapping from token byte sequences to integer ranks along with mappings for special tokens.

The get_encoding() factory function retrieves the correct Encoding based on a model name parameter. It loads the vocabulary file, constructs the required mappings, and returns a configured Encoding instance.

get_tokenizer() acts as the main entry point, calling get_encoding() to retrieve the Encoding and configuring a Tokenizer instance. It handles language and task settings, ultimately returning a Tokenizer to the caller.

The Encoding class encapsulates the vocabulary data through attributes like the rank-frequency map and special token mappings. Its encode() and decode() methods handle the core encoding and decoding work by looking up tokens in these internal maps. The Tokenizer delegates most of the work to Encoding while focusing on a user-friendly interface and caching commonly needed values.
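
A brief sketch of obtaining and using a tokenizer this way (the keyword arguments shown reflect the multilingual configuration):

```python
from whisper.tokenizer import get_tokenizer

tokenizer = get_tokenizer(multilingual=True, language="en", task="transcribe")

ids = tokenizer.encode(" hello world")
print(ids)                       # integer token ids
print(tokenizer.decode(ids))     # round-trips back to " hello world"
print(tokenizer.sot_sequence)    # ids for <|startoftranscript|><|en|><|transcribe|>
```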

Text Normalization

References: whisper/normalizers

The BasicTextNormalizer class handles basic text cleaning tasks like removing symbols and diacritics. It initializes with options to control which cleaning steps to apply, such as remove_diacritics. The __call__ method then applies these cleaning functions to input strings in the defined order.

Key cleaning functions include remove_symbols_and_diacritics() which iterates over characters, checking Unicode categories to replace unwanted characters. Additional diacritic mappings are defined in ADDITIONAL_DIACRITICS to normalize characters not covered by Unicode. The class provides a consistent interface for basic text preprocessing.

The EnglishNumberNormalizer class contains the core number parsing logic. It initializes dictionaries mapping number words to values like self.ones and self.tens. The process_words() method iterates through split input, using these mappings to concatenate numeric words while properly handling suffixes, fractions, and currencies. This converts spelled-out numbers to normalized Arabic numerals so that different renderings of the same value compare consistently.

The EnglishTextNormalizer handles additional English-specific preprocessing. It has dictionaries like self.replacers to expand contractions. The __call__ method applies preprocessing steps in order: lowercasing, expansion, number normalization with EnglishNumberNormalizer, spelling normalization, and cleanup. This fully normalizes English text.
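
A short example of the normalizer in use (the exact output depends on the rule set, but it is along these lines):

```python
from whisper.normalizers import EnglishTextNormalizer

normalizer = EnglishTextNormalizer()

# lowercases, expands contractions, and renders numbers as digits,
# producing something along the lines of "he paid $25 for 2 tickets"
print(normalizer("He paid twenty-five dollars for two tickets."))
```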

Testing

References: tests

The tests directory contains comprehensive unit tests that validate the core functionality of the whisper library. Tests are implemented using the PyTest framework and cover a variety of inputs through parameterization. This ensures the library's key algorithms and components meet specifications across different conditions.

The tests can be grouped into several categories based on the functionality they validate:

Audio Processing: The …/test_audio.py file contains tests for audio loading and preprocessing functionality. It validates that functions like load_audio() and log_mel_spectrogram() produce the correct output and meet expected properties when called directly or on file paths. This confirms audio I/O and feature extraction are working as intended.

Text Normalization: Tests in …/test_normalizer.py validate the text normalization classes like EnglishNumberNormalizer and EnglishSpellingNormalizer. They thoroughly exercise the classes with sample inputs and outputs, verifying the underlying algorithms for tasks like number parsing and spelling correction are correctly implemented.

Timing Operations: The …/test_timing.py file tests common timing functions. It compares the outputs of dtw_cpu(), dtw_cuda(), and median_filter() to reference implementations, and checks the CPU and GPU versions of functions are equivalent. This ensures the timing logic produces accurate and consistent results.

Tokenization: Tests in …/test_tokenizer.py exercise the basic encoding, decoding, and token splitting operations of tokenizers for different languages. This validates the expected tokenization functionality is provided.

Transcription: The …/test_transcribe.py file loads a pre-trained model and tests end-to-end transcription of audio. It asserts properties of the returned transcription text and timing information to ensure the core speech recognition capability works properly.

Unit tests are defined across these files to validate the key components without relying on implementation details. Fixtures defined in …/conftest.py provide shared test functionality and deterministic randomness. Together, the comprehensive suite of tests helps ensure the whisper library meets its design specifications.
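
The style of these tests can be sketched as follows (the cases below are illustrative rather than copied from the actual test files):

```python
import pytest

from whisper.normalizers.english import EnglishNumberNormalizer


@pytest.mark.parametrize(
    "text,expected",
    [
        ("thirty one", "31"),
        ("twenty five dollars", "$25"),
    ],
)
def test_number_normalization(text, expected):
    assert EnglishNumberNormalizer()(text) == expected
```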

Text Normalization

References: whisper/normalizers

The main classes for text normalization are BasicTextNormalizer and EnglishTextNormalizer. BasicTextNormalizer handles basic cleaning tasks like lowercasing and removing symbols and diacritics. It is defined in …/basic.py. EnglishTextNormalizer performs additional normalization steps for English text such as expanding contractions and standardizing numbers and spellings. It is defined in …/english.py alongside the helper classes it delegates to.

BasicTextNormalizer initializes with options like whether to remove diacritics. Its __call__ method handles the main cleaning workflow, applying functions specified in initialization like remove_symbols_and_diacritics(). This function iterates over characters, checking Unicode categories to replace or keep characters.

EnglishNumberNormalizer contains the core number parsing logic. It initializes dictionaries mapping number words to values like self.ones. Method process_words() iterates words, concatenating numbers using the mappings while handling suffixes, fractions and currencies.

EnglishSpellingNormalizer applies a British-to-American spelling mapping, loaded from a bundled JSON file, to each word.

EnglishTextNormalizer standardizes text by applying preprocessing steps in order, including those handled by EnglishNumberNormalizer and EnglishSpellingNormalizer.

The main interfaces are BasicTextNormalizer for basic cleaning and EnglishTextNormalizer which coordinates additional normalization types for English. EnglishNumberNormalizer implements the key number parsing logic through initialization mappings and text iteration.

Audio Processing

References: whisper

The core audio preprocessing functionality is encapsulated in the …/audio.py module. This module provides utilities for loading raw audio files, resampling to a standard sample rate, and converting audio clips to log mel spectrograms for use as model input.

The load_audio function handles opening audio files in a cross-platform way using FFmpeg. It downmixes stereo files to mono and resamples the audio to 16 kHz to produce a consistent representation. This ensures the audio data is standardized before further processing.

The log_mel_spectrogram function converts audio clips to spectrograms. It first computes the short-time Fourier transform (STFT) of each windowed frame of audio. It then projects the STFT onto a mel scale filterbank using precomputed Mel filter weights loaded by mel_filters. Applying the log function produces the final log mel spectrogram tensor. This tensor can then be used directly as input for audio models.

Key aspects of the implementation include downmixing and resampling audio for consistency, precomputing the Mel filter weights for efficiency, and using PyTorch tensors and modules for flexibility and integration with deep learning models. The standardized log mel spectrogram representation facilitates training audio models on large datasets.

Audio Loading

References: whisper/audio.py

The load_audio function handles loading audio files in a variety of formats. Under the hood it uses FFmpeg to open the audio, downmix to mono if needed, and resample to a 16 kHz sample rate. This provides a consistent audio representation for later processing regardless of the original format or properties.

load_audio handles the key audio loading logic by calling out to FFmpeg to do the decoding, downmixing, and resampling. Downmixing ensures a consistent mono signal, while resampling to 16 kHz standardizes the sample rate. The resampled PCM audio is returned as a float32 NumPy array; functions further along the pipeline, such as log_mel_spectrogram, accept either a file path or an already-loaded array. By encapsulating the FFmpeg invocation, load_audio provides an easy and cross-platform way to open audio files in Python.
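
An illustrative sketch of this FFmpeg-based approach (the real load_audio() differs in details such as error handling and the exact ffmpeg flags):

```python
import subprocess

import numpy as np

SAMPLE_RATE = 16000

def load_audio_sketch(path: str, sr: int = SAMPLE_RATE) -> np.ndarray:
    # ask ffmpeg to decode to mono, 16-bit PCM, at the target sample rate
    cmd = ["ffmpeg", "-nostdin", "-i", path, "-f", "s16le", "-ac", "1", "-ar", str(sr), "-"]
    pcm = subprocess.run(cmd, capture_output=True, check=True).stdout
    # convert raw PCM bytes to float32 samples in [-1, 1]
    return np.frombuffer(pcm, np.int16).astype(np.float32) / 32768.0
```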

Spectrogram Generation

References: whisper/audio.py

The log_mel_spectrogram function in …/audio.py handles converting audio to mel spectrograms for model input. It takes a raw audio source as input, either a file path or NumPy array. The Short-Time Fourier Transform (STFT) is computed on the audio using a Hann window over overlapping frames (a 400-sample window with a 160-sample hop). This converts the signal from the time domain to the frequency domain.

The mel filterbank matrices cached by mel_filters are applied to project the STFT coefficients onto the mel scale. This simulates the response of the human auditory system, where perception of frequency is non-linear and resolution decreases with higher frequencies. The mel filters are stored as NumPy arrays to avoid relying on librosa at runtime for improved performance.

A log transform is then applied to compress the dynamic range of the mel spectrogram. This matches the perception of the human ear more closely while also improving numerical stability for downstream tasks. Finally, the tensor is returned on the desired device and padded if needed to match the expected input size of the encoder models.

This processing pipeline encapsulates all of the standard steps to convert time-domain audio into a frequency-domain representation optimized for speech recognition. The mel spectrogram acts as a compact intermediate representation of the audio content, capturing important characteristics for modeling while discarding phase information unimportant for speech tasks.
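
A minimal sketch of this pipeline using the library's constants (400-sample window, 160-sample hop, 80 mel bands at 16 kHz); librosa is used here only to show where a mel filterbank could come from, whereas the package itself ships precomputed filters:

```python
import librosa
import numpy as np
import torch

N_FFT, HOP_LENGTH, N_MELS, SAMPLE_RATE = 400, 160, 80, 16000

def log_mel_sketch(audio: np.ndarray) -> torch.Tensor:
    window = torch.hann_window(N_FFT)
    stft = torch.stft(torch.from_numpy(audio).float(), N_FFT, HOP_LENGTH,
                      window=window, return_complex=True)
    magnitudes = stft.abs() ** 2                            # power spectrogram

    filters = torch.from_numpy(
        librosa.filters.mel(sr=SAMPLE_RATE, n_fft=N_FFT, n_mels=N_MELS)
    ).float()
    mel = filters @ magnitudes                              # project onto the mel scale

    log_mel = torch.clamp(mel, min=1e-10).log10()           # log compression
    log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)   # limit the dynamic range
    return (log_mel + 4.0) / 4.0                            # rough normalization
```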

Audio Padding

References: whisper/audio.py

The pad_or_trim function in …/audio.py handles padding or trimming audio/spectrogram tensors to a fixed expected length. This is important for the encoder models which expect a standard input size. It supports both NumPy arrays and PyTorch tensors.

The function takes an audio tensor as input and pads or trims it to the N_SAMPLES constant defined in the file by default. This represents the number of audio samples (30 seconds at 16 kHz) the encoder models accept as input.

Padding is done by computing how many values are missing and appending that many zeros to the end of the tensor along the last axis, using PyTorch's F.pad or NumPy's np.pad functions.

Trimming is done by slicing the tensor to only keep values from index 0 to N_SAMPLES. This chops off any excess frames beyond what the models can process.

The function handles both 1D audio arrays and multi-dimensional spectrogram tensors by padding along a configurable axis (the last axis, i.e. time, by default). This produces a consistent tensor shape that can be passed directly to the encoder.

It also supports data in NumPy or PyTorch format by using the appropriate functions like F.pad or np.pad. This allows preprocessing pipelines to seamlessly work with either backend.

The pad_or_trim function is a key part of preparing variable length audio samples for model ingestion. By normalizing lengths it enables batching for efficient training and inference.
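
A hedged sketch of this behavior for NumPy inputs (the real function also accepts PyTorch tensors and an explicit axis argument):

```python
import numpy as np

N_SAMPLES = 16000 * 30   # 30 seconds of audio at 16 kHz

def pad_or_trim_sketch(array: np.ndarray, length: int = N_SAMPLES) -> np.ndarray:
    if array.shape[-1] > length:
        return array[..., :length]                 # trim excess samples beyond the window
    if array.shape[-1] < length:
        pad = length - array.shape[-1]
        pad_widths = [(0, 0)] * (array.ndim - 1) + [(0, pad)]
        return np.pad(array, pad_widths)           # zero-pad the end of the last axis
    return array
```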

Model Architecture

References: whisper

The …/model.py file defines the core model architecture for the Whisper speech recognition system. It contains classes that implement the standard encoder-decoder structure used in many sequence-to-sequence models.

The ModelDimensions dataclass defines important hyperparameters like the number of encoder/decoder layers and attention heads. The MultiHeadAttention class implements the multi-head dot product attention that serves as the basic building block used throughout the model.

The ResidualAttentionBlock class represents a single residual attention block, containing MultiHeadAttention and feed-forward layers with residual connections. These blocks are arranged in stacks to form the encoder and decoder.

The AudioEncoder class contains a front-end stack of 1D convolutional layers followed by residual attention blocks. It encodes input mel spectrograms into high-level representations.

The TextDecoder class decodes the encoded audio representations using a stack of residual attention blocks. These blocks utilize cross-attention between the decoder inputs and encoded audio via another MultiHeadAttention layer to predict output text tokens autoregressively.

The full Whisper model combines the AudioEncoder and TextDecoder into a single end-to-end speech recognition model. It handles moving modules to the correct device and passing data between the encoder and decoder during inference.

Some key implementation details include using mixed precision in the LayerNorm and Linear subclasses, caching attention keys/values for efficiency in MultiHeadAttention, and generating sinusoidal position encodings.
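
The sinusoidal position encodings mentioned above can be sketched as follows (this mirrors the standard Transformer recipe that sinusoids() follows):

```python
import numpy as np
import torch

def sinusoids_sketch(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, None] * inv_timescales[None, :]
    # concatenate sine and cosine channels -> (length, channels)
    return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)
```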

Encoder

References: whisper/model.py

The AudioEncoder module is responsible for encoding input mel spectrograms into a latent representation that can be decoded into text. It uses convolutional and self-attention layers to process the spectrogram inputs.

The core component is the ResidualAttentionBlock, which is a building block that applies a multi-head attention mechanism and residual connection. Multiple ResidualAttentionBlocks make up the AudioEncoder, with earlier layers using 1D convolutions to process local spectrogram features before passing to self-attention.

The MultiHeadAttention module is used within each ResidualAttentionBlock. It implements scaled dot-product attention, projecting queries, keys and values into multiple heads that operate in parallel. During decoding, keys and values can be cached between steps using the hooks installed by install_kv_cache_hooks().

Sinusoidal position embeddings generated by sinusoids() are added to the convolutional outputs once, before the stack of attention blocks. This allows the encoder to utilize positional information without learned embeddings.

Dtype casting for mixed precision is handled by the LayerNorm, Linear and Conv1d subclasses defined in the file. This enables the model to operate on tensors with different numeric types like float16 and float32.

The encoded audio representation generated by the AudioEncoder is then passed to the TextDecoder to generate text predictions through cross-attention between the encoded audio and partial decoding outputs.

Decoder

References: whisper/model.py

The TextDecoder is responsible for decoding the encoded audio representations from the AudioEncoder and generating text predictions. It takes the encoded audio as input and attends to it using multi-head cross attention in its residual attention blocks.

The core component of the TextDecoder is the ResidualAttentionBlock, which implements the self-attention and cross-attention sublayers. At each timestep, it performs self-attention over the previously generated tokens to incorporate context, and cross-attention to the encoded audio to focus on relevant parts of the input.

The cross-attention weights are computed using MultiHeadAttention between the text queries and audio keys/values. The attention heads learn to focus on different temporal regions to extract complementary information.

After cross attention, the final hidden states are projected to logits over the text vocabulary (by reusing the token embedding matrix). A softmax produces the final token probabilities.

At inference time, greedy decoding or beam search is used to find likely token sequences, depending on the DecodingOptions.

Full Model

References: whisper/model.py

The full Whisper model combines the AudioEncoder and TextDecoder classes to perform speech-to-text. The AudioEncoder encodes input mel spectrograms into high-level audio representations using convolutional and self-attention layers defined in ResidualAttentionBlock.

The encoded audio is passed to the TextDecoder, which attends to the encoder outputs using cross-attention. This allows information about the input audio to be incorporated when generating text predictions. The TextDecoder contains its own ResidualAttentionBlock layers that process both the encoded audio and previously generated tokens.

At initialization, the Whisper model simply constructs the encoder and decoder modules. It defines the forward pass, which runs the input audio through the encoder to obtain context vectors, passes them to the decoder, and generates token predictions.

Key details:

  • The encoder and decoder are connected through cross-attention in the decoder blocks.
  • Sinusoidal positional encodings are used in the encoder, while the decoder uses learned positional embeddings.
  • Multilingual support is detected by checking the size of the text embedding matrix.
  • Mixed precision is handled by LayerNorm, Linear, Conv1d subclasses.
  • Attention keys/values are cached between positions for efficiency.

Decoding

References: whisper

The core decoding logic is handled by the DecodingTask class defined in …/decoding.py. DecodingTask manages the overall decoding process by initializing important components like the Tokenizer, Inference implementation, and TokenDecoder based on the provided DecodingOptions.

The main responsibilities of DecodingTask include:

  • Detecting the spoken language when it is not specified in the options
  • Applying any prompt or prefix tokens before decoding begins
  • Running the autoregressive decoding loop in _main_loop()
  • Applying LogitFilter rules to the logits at each step
  • Formatting the output into DecodingResult objects

The Inference interface defines how model forwarding is handled during decoding. PyTorchInference handles caching of attention keys/values between positions for efficiency by installing hooks on the model's decoder blocks.

TokenDecoder implementations like GreedyDecoder and BeamSearchDecoder define the algorithms for selecting the next most likely token during decoding. GreedyDecoder picks the single most probable token while BeamSearchDecoder maintains a beam of candidate sequences.

LogitFilter classes apply rules to filter or penalize certain tokens predicted by the model, like suppressing blank tokens or timestamps that don't follow patterns. This helps improve consistency.

The file implements all major components for flexible yet efficient decoding, from language detection and audio encoding to the core decoding loop managed by DecodingTask with optimized model forwarding via Inference.

Token Decoding

References: whisper/decoding.py

The DecodingTask class manages token decoding by initializing important components during decoding like the Tokenizer, Inference implementation for model forwarding, and TokenDecoder for selecting the next token.

The _main_loop() method in DecodingTask runs the core autoregressive decoding loop. At each step, it uses the TokenDecoder to select the next token.

The TokenDecoder interface defines how the next token is selected. Classes like GreedyDecoder and BeamSearchDecoder implement different decoding algorithms. GreedyDecoder implements greedy decoding by always selecting the token with highest logit score. BeamSearchDecoder maintains a beam of candidate sequences and advances them in parallel at each step.

LogitFilter classes can also be applied during decoding. These modify the logits before the next token is selected, such as suppressing blank tokens or tokens unlikely based on timing. This allows inserting priors and improving accuracy.
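
The filtering idea can be sketched as a small class that masks disallowed token ids before selection (the class name here is hypothetical; the real filters follow the same pattern of editing the logits in place):

```python
import torch

class SuppressTokensSketch:
    """Mask out a fixed set of token ids by setting their logits to -inf."""

    def __init__(self, suppress_ids: list[int]):
        self.suppress_ids = suppress_ids

    def apply(self, logits: torch.Tensor, tokens: torch.Tensor) -> None:
        # logits: (batch, vocab); modified in place before the next token is chosen
        logits[:, self.suppress_ids] = -float("inf")
```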

Language Model

References: whisper

The language model is responsible for predicting the next most likely token during decoding. It utilizes a Transformer decoder model trained on speech data to generate token sequences that correspond to the input audio.

The core component is the TextDecoder class defined in …/model.py. The TextDecoder contains a stack of residual attention blocks that incorporate context from the encoded audio representations. At each timestep, it generates predictions for the next token id by attending to the encoded audio using its multi-head attention modules.

The forward() method handles the core autoregressive generation process. It takes the encoded audio and previous predicted tokens as input, passes them through the residual attention blocks, and outputs logits over the vocabulary at each timestep.

A key aspect is that the MultiHeadAttention module caches attention keys and values between positions for efficiency. This avoids recomputing them at each timestep during decoding.

The Whisper model class defined in …/model.py combines the AudioEncoder and TextDecoder into a single PyTorch module. Its forward() method runs the full encode-decode process.

During inference, the DecodingTask class in …/decoding.py manages running the Whisper model autoregressively. Classes like GreedyDecoder implement algorithms for selecting the next predicted token at each timestep based on the logits.

Greedy Decoding

References: whisper

Greedy decoding sequentially predicts the most likely token at each timestep of the audio, without searching for the best overall sequence. The TokenDecoder interface defines how the next token is selected during decoding. The GreedyDecoder class implements greedy decoding by selecting the token with the highest predicted probability at each step.

The GreedyDecoder is initialized with options like an end-of-sequence token and a sampling temperature. Its update() method handles a single decoding step: given the logits produced via Inference, it selects the argmax (or samples when a nonzero temperature is set), appends the token, and checks for the end token; the surrounding loop lives in DecodingTask._main_loop().
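
A simplified sketch of the greedy loop (argument names are illustrative; the real implementation also tracks log-probabilities, temperature sampling, and batched completion):

```python
import torch

def greedy_decode_sketch(model, audio_features, sot_ids, eot_id, max_tokens=224):
    tokens = torch.tensor([sot_ids])                       # (1, n_prompt) start-of-transcript prompt
    for _ in range(max_tokens):
        logits = model.decoder(tokens, audio_features)     # (1, n_tokens, n_vocab)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=-1)   # append the locally best token
        if next_token.item() == eot_id:                    # stop at end-of-text
            break
    return tokens[0].tolist()
```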

Some key aspects of greedy decoding include:

  • It makes locally optimal choices at each step without searching for the best overall sequence.
  • This makes it fast but can produce suboptimal results compared to search algorithms like beam search.
  • GreedyDecoder implements this by selecting the single most probable token from the model predictions.
  • Key/value caching in the Inference layer keeps the per-step cost of decoding low.

The DecodingTask manages the full decoding process including initializing important components like the Tokenizer, Inference and TokenDecoder implementations. It handles preprocessing steps and calls _main_loop() which runs the core autoregressive decoding loop, applying the configured TokenDecoder at each step.

Speech Recognition

References: tests

The core functionality of speech recognition in whisper is transcribing audio into text. This is handled through several key components:

The …/audio.py module contains functions for loading audio files and preprocessing them. load_audio() reads audio files and returns a numpy array. log_mel_spectrogram() performs a short-time Fourier transform, maps frequencies to mel scale, and takes the log to convert to decibels, generating mel spectrograms for model input.

The …/model.py module defines the encoder, decoder, and full model architecture. The AudioEncoder class encodes input spectrograms. The TextDecoder class decodes the encoder output and predicts output text tokens. The full Whisper model combines these components.

Transcription ties these pieces together: the model's transcribe() method splits audio into 30-second windows, decodes each window, and assembles the resulting text and timestamps.

The …/test_transcribe.py file loads a pre-trained model and tests the transcription of audio, validating properties of the returned transcription text and timing information. This ensures the end-to-end speech recognition capability works as expected.

Unit tests in tests extensively validate core functionality like audio processing and model architecture. They provide a comprehensive test suite to ensure whisper meets specifications without relying on implementation details.

Audio Processing

References: whisper

The …/audio.py file contains utilities for loading audio files and converting them to mel spectrograms. The load_audio() function opens audio files and handles resampling and downmixing to produce a standardized audio representation.

The core preprocessing is done by the log_mel_spectrogram() function. It takes the loaded audio and first computes the spectrogram. It then projects the spectrogram onto the mel scale using precomputed Mel filterbank matrices loaded from mel_filters(). This projects the frequencies into mel bands. Finally, it takes the log of each value to convert the mel spectrogram to decibels. This compressed representation is then returned and can be used as input to the audio encoder models.

The pad_or_trim() function ensures any mel spectrograms have a fixed expected length required by the encoder models. It handles padding or trimming the spectrogram arrays to the standard number of frames defined in the code. This provides consistent length input to the models.

Model Architecture

References: whisper

The Whisper model follows an encoder-decoder architecture for speech recognition. The AudioEncoder encodes variable-length input spectrograms into fixed-dimensional contextual representations. It contains convolutional and self-attention layers arranged in an encoder stack to process the spectral inputs.

The TextDecoder generates predicted text tokens autoregressively. It contains a similar stack of self-attention layers with a cross-attention layer to relate the encoded audio representations to predictions at each timestep.

The Whisper class represents the full model, containing the AudioEncoder and TextDecoder. It handles model initialization, forwarding batches of spectrogram inputs through the encoder and decoder, and returning predictions.

The AudioEncoder uses Conv1d modules with activations to process spectrogram frames. It contains ResidualAttentionBlock modules arranged in an encoder stack. These blocks utilize MultiHeadAttention to relate different parts of the input sequence.

The TextDecoder contains a stack of ResidualAttentionBlock modules like the encoder. But its blocks include MultiHeadAttention to attend to the encoded audio representations. This cross-attention allows mapping from encoded audio to predicted text tokens at each timestep.

The MultiHeadAttention module implements the core attention mechanism. It projects queries, keys and values, calculates attention scores, and returns weighted sums of values. The ResidualAttentionBlock combines a multi-head attention with feed-forward layers and a residual connection.

Decoding

References: whisper

The core decoding process is handled by the DecodingTask class in …/decoding.py. DecodingTask manages initializing important components like the Tokenizer, Inference implementation, TokenDecoder and any LogitFilters based on provided DecodingOptions.

The main steps are:

  • Audio preprocessing into mel spectrograms with …/audio.py
  • Language identification with detect_language()
  • Model forwarding and caching during decoding with Inference
  • Core autoregressive decoding loop with TokenDecoder applying decoding algorithms like beam search
  • Applying LogitFilter rules to tokens during decoding

DecodingTask handles preprocessing like language identification or applying a prompt/prefix. Its _main_loop() method runs the core autoregressive decoding, applying the TokenDecoder and LogitFilters at each step.

The Inference interface defines model forwarding, with PyTorchInference caching keys/values from the decoder blocks.

TokenDecoder defines how the next token is selected, with classes like GreedyDecoder and BeamSearchDecoder implementing different algorithms.

LogitFilter classes apply rules to the logits like suppressing blank tokens or timestamps that don't follow patterns.

Heuristics help improve consistency, like no-speech detection, hallucination filtering, and prompt resetting between windows.

Token Decoding

References: whisper/decoding.py

The DecodingTask class manages token decoding by initializing important components during decoding like the Tokenizer, Inference implementation for model forwarding, and TokenDecoder for selecting the next token. It handles preprocessing steps before decoding begins.

The _main_loop() method runs the core autoregressive decoding loop, applying the TokenDecoder at each step to select the next token. This is done by forwarding the audio through the model to get logits, then applying any LogitFilters to modify the logits before selecting the next token.

The Inference interface handles model forwarding and caches outputs to improve speed. The PyTorchInference implementation handles caching keys and values from the decoder blocks.

Key classes involved in token decoding include the TokenDecoder implementations GreedyDecoder and BeamSearchDecoder, along with the LogitFilter classes that constrain the logits at each step.

Language Model

References: whisper

The language model predicts the next token in the sequence. It utilizes a Transformer architecture similar to the encoder, but processes the encoder outputs with cross-attention to generate text predictions at each timestep. The model follows an auto-regressive approach, using its previous predictions as inputs to generate the next tokens.

The core component is the TextDecoder defined in …/model.py. It contains a stack of residual attention blocks like the encoder, with the addition of cross-attention to the encoded audio representations from the encoder using MultiHeadAttention. This cross-attention allows the model to map from encoded audio to the text space at each timestep.

Inside the residual attention blocks, MultiHeadAttention is used to relate textual query embeddings to key-value pairs from previous blocks in the decoder stack (self-attention) as well as key-value pairs from the final encoder block (cross-attention). The attention weights are used to aggregate contextual information that is fed to the subsequent layers.

Some key aspects of the implementation include mixed precision support via custom LayerNorm and Linear subclasses, and caching of attention keys/values for efficiency via MultiHeadAttention. The TextDecoder uses learned token and positional embeddings, whereas the AudioEncoder uses fixed sinusoidal positional encodings generated with sinusoids().

The full model combines the AudioEncoder and TextDecoder into a single Whisper class defined in …/model.py. Its forward() method handles passing inputs through the encoder and decoder to generate predictions. The model follows a standard encoder-decoder paradigm for sequence-to-sequence problems like speech recognition.

Beam Search

References: whisper

The BeamSearchDecoder class implements beam search decoding to find the most likely transcription of audio given a pre-trained Whisper model. At each timestep, it maintains the top-K scoring partial hypotheses and expands each by one token. The scores are calculated based on the log probabilities from the model added to any heuristic scores.
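
The per-step bookkeeping can be sketched as follows (a simplified single-audio version; the real BeamSearchDecoder also tracks finished hypotheses and a patience factor):

```python
import torch
import torch.nn.functional as F

def beam_step_sketch(logits: torch.Tensor, beams: list[tuple[list[int], float]], k: int):
    """logits: (n_beams, vocab) for the current step; beams: (tokens, cumulative log-prob) pairs."""
    logprobs = F.log_softmax(logits.float(), dim=-1)
    candidates = []
    for i, (tokens, score) in enumerate(beams):
        topk = logprobs[i].topk(k)
        for lp, idx in zip(topk.values.tolist(), topk.indices.tolist()):
            candidates.append((tokens + [idx], score + lp))   # extend each hypothesis by one token
    candidates.sort(key=lambda c: c[1], reverse=True)
    return candidates[:k]                                     # keep only the k best hypotheses
```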

Some key aspects of the implementation:

  • Beam search decoding is handled by implementing the token-selection method of TokenDecoder in the BeamSearchDecoder subclass.

  • It maintains the partial hypotheses in a data structure, which supports efficient operations like retrieving/updating the top scores.

  • Heuristic scores can be applied when ranking hypotheses, such as a length penalty. This guides the search towards more likely sequences.

  • Caching is used to efficiently retrieve the log probabilities for the entire beam in one pass.

  • Parameters like beam width and length penalty can be tuned to control the breadth and depth of the search.

Greedy Decoding

References: whisper

Greedy decoding sequentially predicts the most likely next token at each step based on the model probabilities, without looking ahead. This approach is implemented in the GreedyDecoder class defined in …/decoding.py.

The GreedyDecoder class inherits from the TokenDecoder interface. It implements the per-step selection method, which picks the token with the highest probability (or samples from the distribution when a temperature is set).

The surrounding decoding loop lives in DecodingTask. At each step, it gets the logits for all possible next tokens from the model, lets the decoder pick the next token via argmax, appends it to the sequence, and feeds the extended sequence back as input for the next step.

After decoding an audio chunk, the predictions are post-processed before being returned. Timing utilities align the predictions to the original audio using cross-attention weights to generate timestamp information for each predicted word.

Additional classes can modify the model outputs during decoding. For example, a filter may suppress blank or repeated tokens to encourage diversity in predictions. This helps improve consistency across output segments.

Tokenization

References: whisper

Tokenization handles converting text into discrete tokens for model input and output. The Tokenizer class provided by whisper.tokenizer handles the core tokenization tasks. It initializes with an Encoding object which defines the vocabulary and tokenization rules for a language.

The Encoding class represents the vocabulary and contains mappings from words to integers. Its constructor builds the rank-frequency mapping from the vocabulary text file and initializes mappings for special tokens.

The main methods on Tokenizer are encode(), decode(), and decode_with_timestamps(). encode() takes raw text and uses the Encoding to convert it to a list of integer ids. decode() performs the reverse conversion from ids to text. decode_with_timestamps() is similar but also extracts character position information to align tokens to timestamps.

Tokenizer has some cached properties like language_token and non_speech_tokens that efficiently extract commonly needed special token values from the underlying Encoding. These optimizations improve performance of common operations.

The split_to_word_tokens() method handles language-specific preprocessing, splitting text into word tokens depending on the language's conventions for word boundaries.

The get_encoding() factory function retrieves a configured Encoding instance based on the model name and vocabulary file. It constructs the necessary mappings, building the core representation of the vocabulary.

Text Normalization

References: whisper

The BasicTextNormalizer class provides basic text cleaning functionality. It initializes with options like whether to remove diacritics, then applies functions to clean strings. The remove_symbols_and_diacritics() function iterates over characters, checking Unicode categories to replace, map, or keep characters based on the category.

The EnglishNumberNormalizer class handles number parsing. It initializes dictionaries mapping number words to values like self.ones and self.tens. The process_words() method iterates through split words, using the mappings to concatenate numbers while handling suffixes, fractions, currencies, etc.

The EnglishSpellingNormalizer applies a loaded British-American spelling mapping to each word.

The EnglishTextNormalizer class standardizes text by applying steps like lowercase, expanding contractions from self.replacers, and delegating number and spelling normalization before cleanup.

Testing

References: whisper

The tests directory contains extensive unit tests for validating core functionality throughout the whisper library. Comprehensive testing is implemented using PyTest and parameterization to cover a wide range of inputs and corner cases.

Key areas that are tested include:

  • Audio loading and spectrogram generation in …/test_audio.py
  • Text normalization in …/test_normalizer.py
  • Timing operations such as DTW and median filtering in …/test_timing.py
  • Tokenization in …/test_tokenizer.py
  • End-to-end transcription in …/test_transcribe.py

The tests utilize fixtures defined in …/conftest.py for seeding random numbers. They validate outputs against expected values or equivalence with other libraries' implementations like SciPy. Comprehensive parameterization covers many inputs and scenarios.

Some important implementation details:

The EnglishNumberNormalizer class contains mappings from number words to values in dictionaries like self.ones and self.tens. Its process_words() method iterates over split words, concatenating numbers using the mappings while handling suffixes, fractions, currencies etc. Extensive testing validates the parsing algorithms.

The median_filter() function computes a sliding-window median in PyTorch, using reflective padding to match SciPy's behavior. Tests compare its results against SciPy's implementation to validate equivalence.

test_transcribe() loads a model, runs transcription, and asserts properties of the returned text and timing against expectations, validating end-to-end functionality.

test_tokenizer() validates encoding, decoding, splitting operations of classes like Tokenizer and Encoding that encapsulate tokenization logic.

Text Normalization

References: whisper

The whisper library provides several classes for normalizing text. The BasicTextNormalizer class in the …/basic.py file handles basic text cleaning operations like removing symbols, diacritics, and splitting text on letter boundaries. The EnglishTextNormalizer class found in …/english.py performs additional preprocessing of English text.

The EnglishTextNormalizer class standardizes English text using several sub-classes. The EnglishNumberNormalizer class contains algorithms for parsing numeric words like "twenty five" and converting them to standardized Arabic numerals. It initializes mappings from number words to values in dictionaries like self.ones and self.tens. The process_words method iterates over split words, using the mappings to concatenate numbers while handling suffixes, fractions, currencies etc.

The EnglishSpellingNormalizer class applies a loaded British-American spelling mapping to standardize variants such as "behaviour" and "behavior". The EnglishTextNormalizer handles additional preprocessing by applying steps like removing filler words, expanding contractions using the self.replacers dictionary, and delegating number and spelling normalization to other classes before cleanup.

The BasicTextNormalizer initializes with options like whether to remove diacritics. When called on a string, it first applies lowercase, then uses regular expressions and cleaning functions to remove symbols and diacritics. It optionally splits the string on letter boundaries. The cleaning functions iterate characters, replacing, mapping, or keeping them based on Unicode categories.

Audio Processing

References: whisper

The …/audio.py file contains utilities for loading, preprocessing, and converting audio files into spectrograms. It handles common audio processing tasks needed as input for the speech models.

The main functionality includes:

  • Loading audio files with load_audio() which handles resampling and downmixing audio into a standardized format. It uses FFmpeg under the hood.

  • Preprocessing audio with functions like pad_or_trim() which pads or trims waveforms to a fixed expected length for the models.

  • Generating spectrograms with log_mel_spectrogram() which computes the short-time Fourier transform, maps frequencies to mel scale with precomputed filterbanks, and returns the log mel spectrogram in decibels.

Some key implementation details:

  • load_audio() provides a consistent way to open raw audio files across platforms via FFmpeg. It downmixes to mono and resamples to 16kHz sample rate.

  • pad_or_trim() handles padding or trimming waveforms to the N_SAMPLES samples expected by the encoder models. It supports NumPy and PyTorch inputs.

  • mel_filters() caches precomputed Mel filterbank matrices, avoiding external library calls during spectrogram generation for improved performance.

  • log_mel_spectrogram() encapsulates the standard processing pipeline, moving data to the given device and optionally padding before returning log mel spectrograms.

Audio Loading

References: whisper

The load_audio() function handles loading audio files and supports various formats. It uses FFmpeg under the hood to open audio streams in a cross-platform manner. load_audio() downmixes the audio to mono if needed and resamples it to 16kHz sample rate, ensuring a consistent representation. It returns a normalized NumPy array containing the raw audio samples.

The key aspects of load_audio() include:

  • Opening audio files via FFmpeg for cross-platform support across different file formats like FLAC, WAV, MP3 etc.

  • Downmixing multi-channel audio to mono since the models expect a single channel

  • Resampling to 16kHz sample rate to match the models' expected input rate

  • Returning a normalized NumPy array containing the audio samples for further processing

Spectrogram Generation

References: whisper

The log_mel_spectrogram() function in the …/audio.py file is used to convert audio clips into mel spectrograms for use as model input. It takes a raw audio clip as input, either as a file path or NumPy array. The audio is resampled to 16kHz if needed. Mel spectrograms are computed from the short-time Fourier transform of the audio. The frequencies from the STFT are mapped to the mel scale using precomputed Mel filterbank matrices loaded by mel_filters(). Taking the logarithm of each value converts the mel spectrogram to decibels. This processed mel spectrogram is returned as a PyTorch tensor, with optional padding applied to a fixed expected length.

The key aspects of implementation include:

  • The mel_filters() function loads filterbanks for 80 or 128 bands to support different model configurations.

  • Padding mel spectrograms to a fixed length with pad_or_trim() ensures consistency for the encoder models.

  • Moving data to the given device and optional padding produces a standardized input tensor.

Audio Padding

References: whisper

The pad_or_trim() function handles padding or trimming audio/spectrogram tensors to a fixed expected length for model input. This function is defined in …/audio.py. It takes an audio tensor and the expected number of frames as arguments.

pad_or_trim() first determines if padding or trimming is needed by comparing the length of the audio to the expected number of frames. It then allocates a new tensor of the correct size to hold the padded or trimmed audio.

For padding, the function pads the end of the audio tensor with zeros until it reaches the expected length. This is done using PyTorch's F.pad() function with the 'constant' mode and value of 0.

For trimming, it slices the audio tensor to keep only the initial frames up to the expected length using basic tensor slicing.

This ensures all audio samples passed to the model have a consistent fixed length, which is required since the model expects a standard input size. It supports both NumPy arrays and PyTorch tensors as inputs, handling the padding or slicing appropriately based on the framework.

This function plays an important role in audio preprocessing by normalizing the length of all samples. The fixed expected size corresponds to the 30-second audio window the models were trained on. Padding with zeros avoids edge effects while keeping the sample rate and dimensions consistent.

Model Architecture

References: whisper

The Whisper class defined in …/model.py represents the full speech recognition model and combines the AudioEncoder and TextDecoder modules. The AudioEncoder uses convolutional layers to process input mel spectrograms, while the TextDecoder attends to the encoded audio representations using cross-attention to generate predicted text tokens.

The AudioEncoder contains Conv1d and ResidualAttentionBlock modules arranged in an encoder stack. The ResidualAttentionBlock utilizes MultiHeadAttention to relate different parts of the input spectrogram.

The TextDecoder also contains a stack of ResidualAttentionBlock modules, except it cross-attends to the encoded audio representations from the encoder using MultiHeadAttention. This allows the decoder to map from encoded audio to predicted text tokens at each timestep.

The MultiHeadAttention module caches attention keys and values for efficiency. The Whisper class initializes the encoder and decoder, handles device placement, and defines the core forward() method. forward() passes a batch of mel spectrograms through the AudioEncoder, then through the TextDecoder to generate token predictions.

Encoder

References: whisper

The AudioEncoder module handles encoding input mel spectrograms into latent representations that can be decoded into text. It contains the AudioEncoder class which implements the convolutional front-end of the model.

The AudioEncoder class contains a stack of Conv1d and ResidualAttentionBlock modules arranged sequentially to operate on the time axis of the input spectrograms. The Conv1d layers apply 1D convolutions over time, with the mel bands as input channels and the second convolution downsampling by a stride of 2.

The ResidualAttentionBlock modules incorporate multi-head attention from MultiHeadAttention to relate different parts of the input sequence. This allows the model to learn long-range dependencies in the audio. The blocks follow a residual connection scheme where the input is added to the output.

Some key aspects of the implementation:

  • Mixed precision support is provided via custom LayerNorm and Linear subclasses that handle casting weights and activations to different dtypes like float16.

  • Caching of attention keys and values between positions is done in MultiHeadAttention for efficiency.

  • Sinusoidal position embeddings are used instead of learned embeddings to generalize better to variable length inputs.

  • Layer normalization is applied before each attention and feed-forward sublayer (pre-norm), with a final layer norm after the last block, to stabilize training.

The encoded audio representations produced by the AudioEncoder are then passed to the TextDecoder module to generate text token predictions.

Decoder

References: whisper

The TextDecoder module is responsible for predicting text tokens given an input sequence of encoded audio representations. It contains the TextDecoder class which implements the core decoding functionality.

The TextDecoder class contains a stack of residual attention blocks arranged in a decoder architecture. The blocks utilize multi-head attention from the MultiHeadAttention module to relate the encoded audio representations from the AudioEncoder to the predicted text tokens at each timestep. This cross-attention allows the model to map from the encoded audio to the appropriate text token predictions.

The residual attention blocks in the TextDecoder follow the same structure as those in the AudioEncoder, containing MultiHeadAttention modules for self-attention over the predicted tokens as well as cross-attention to the audio encodings. The blocks also include feed-forward (MLP) and LayerNorm layers.

Some key aspects of the TextDecoder implementation include mixed precision support via custom LayerNorm and Linear subclasses to handle different data types efficiently. Caching of attention keys and values via MultiHeadAttention improves performance by reusing computations between timesteps.

Full Model

References: whisper

The Whisper class defined in …/model.py represents the full speech recognition model, combining the AudioEncoder and TextDecoder into a single class. The __init__() method initializes the encoder and decoder, along with other components like embeddings. It also handles moving modules to the specified device like CPU or GPU.

The core forward() method takes a batch of mel spectrograms as input and passes it through the AudioEncoder to get encoded audio representations. It then passes the encodings, together with the token sequence, through the TextDecoder to generate logits over the vocabulary at each position.

The AudioEncoder contains convolutional and self-attention blocks to encode input spectrograms. It uses Conv1d and ResidualAttentionBlock modules arranged in an encoder stack. The blocks utilize multi-head attention from MultiHeadAttention to relate different parts of the input.

The TextDecoder contains a similar stack of residual attention blocks, except it cross-attends to the encoded audio representations from the encoder using the MultiHeadAttention module. This cross-attention allows mapping from encoded audio to predicted text tokens at each timestep.

Some key aspects of the implementation include mixed precision support via custom LayerNorm and Linear subclasses, and caching of attention keys/values for efficiency via MultiHeadAttention.
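
Reusing the two sketch classes above, the shape of a full forward pass looks roughly like this; the greedy argmax stands in for the real decoding loop in …/decoding.py and all dimensions are illustrative:

```python
import torch

encoder = TinyAudioEncoder()
decoder = TinyTextDecoder()

mel = torch.randn(1, 80, 3000)             # one 30-second batch of log-mel frames
tokens = torch.randint(0, 51865, (1, 5))   # partially decoded token sequence

audio_features = encoder(mel)              # (1, 1500, 384)
logits = decoder(tokens, audio_features)   # (1, 5, n_vocab)
next_token = logits[:, -1].argmax(dim=-1)  # greedy choice for the next position
```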

Tokenization

References: whisper

The Tokenizer class handles converting text into discrete tokens for model input and output. It provides methods like encode(), decode(), and decode_with_timestamps() to tokenize strings.

The Tokenizer initializes with an Encoding instance which represents the vocabulary and tokenization rules. Encoding contains the mappings and rules for a particular language or model, and handles the core encoding and decoding logic by mapping text to integer IDs.

The Encoding instance (a tiktoken Encoding) is built from a vocabulary file of byte-pair-encoding merge ranks together with a table of special tokens. Its encode() and decode() methods apply the BPE rules to convert text to integer IDs and back.

The get_encoding() factory function constructs the appropriate Encoding based on parameters such as whether the vocabulary is English-only or multilingual and the number of supported languages. It loads the corresponding vocabulary file, builds the required mappings, and registers the special tokens.

The get_tokenizer() factory ties it all together by calling get_encoding(), configuring a Tokenizer instance with the returned Encoding, and setting additional properties. This provides a simple interface to obtain a fully initialized tokenizer.

Cached properties on Tokenizer like language_token and non_speech_tokens extract commonly needed values from the underlying Encoding in an efficient manner without repeated method calls.

Methods such as split_to_word_tokens() group decoded tokens into words, handling languages that do not delimit words with spaces. Overall the classes encapsulate the tokenization logic, while the factories abstract away configuration details.
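
A small usage example of this interface is shown below; the printed values in the comments are indicative rather than verified outputs:

```python
from whisper.tokenizer import get_tokenizer

# multilingual tokenizer configured for English transcription
tokenizer = get_tokenizer(multilingual=True, language="en", task="transcribe")

ids = tokenizer.encode(" Hello world")
assert tokenizer.decode(ids) == " Hello world"   # encode/decode round-trip

# cached properties expose frequently used special tokens
print(tokenizer.sot_sequence)    # ids for <|startoftranscript|>, <|en|>, <|transcribe|>
print(tokenizer.language_token)  # integer id of the <|en|> token
```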

Evaluation Datasets

References: whisper

The data directory contains datasets that were used to evaluate speech recognition models. It includes short-form English-only datasets, long-form English-only datasets, and multilingual datasets from a variety of sources. Some of the datasets located here include LibriSpeech, TED-LIUM, Common Voice, WSJ, and Multilingual LibriSpeech.

The …/README.md file documents the sources and preprocessing details of each dataset. For the long-form datasets, it provides timestamps used to slice audio segments from TED-LIUM talks and lists source IDs for samples used from the Kincaid46 dataset. No classes or functions are defined in data itself; its main purpose is to provide the datasets and transparency into the data preparation process through …/README.md, enabling reproducibility of experiments.

Testing

References: whisper

The tests directory contains comprehensive unit tests for validating the core functionality of the whisper library. Tests are organized into modules that match the package structure, for example test_audio.py tests functions in …/audio.py.

Key test modules include …/test_audio.py for audio loading and spectrogram generation, …/test_normalizer.py for text normalization, …/test_timing.py for dynamic time warping and median filtering, …/test_tokenizer.py for tokenization, and …/test_transcribe.py for end-to-end transcription.

The …/conftest.py file defines fixtures like random() that seed the random number generators. This ensures tests are reproducible.

In summary, the tests aim to flush out issues through rigorous validation of core functionality, algorithms, classes, and functions, providing a comprehensive suite for quality-assuring whisper.

Tokenization

References: whisper

Tokenization converts text into discrete tokens for model input and output. The …/tokenizer.py file contains utilities that handle this task.

The main classes defined in this file are Tokenizer and Encoding. Tokenizer provides a thin wrapper around the tiktoken tokenizer for encoding and decoding text, and handles access to special tokens. Encoding represents the vocabulary and tokenization rules for a particular model/language, encoding and decoding text by mapping subword units to integer IDs.

The get_encoding() function retrieves the appropriate Encoding instance based on parameters like the model name and number of supported languages. get_tokenizer() constructs a Tokenizer instance, configuring properties and retrieving the matching Encoding.

Tokenizer initializes special token IDs and delegates core tokenization methods like encode(), decode(), and decode_with_timestamps() to the underlying Encoding. Cached properties like language_token optimize common operations.

Encoding represents the vocabulary and rules. It is constructed from byte-pair-encoding merge ranks and special-token mappings; its methods encode text into integer IDs and decode IDs back into text.

The classes encapsulate the tokenization logic, while factory functions handle configuration and retrieving instances. Cached properties in Tokenizer improve performance. This implementation provides a flexible yet optimized approach to tokenization.

Tokenizer

References: whisper/tokenizer.py

The Tokenizer class handles converting text to discrete tokens using an Encoding. Tokenizer acts as a thin wrapper around the tiktoken tokenizer that provides access to special tokens. Its main responsibilities are encoding and decoding text.

Tokenizer initializes special token IDs by calling methods on the underlying Encoding instance. Methods like encode(), decode(), and decode_with_timestamps() delegate the core tokenization work to Encoding.

Cached properties such as language_token and non_speech_tokens extract commonly needed values from Encoding for optimized performance of common operations. split_to_word_tokens() groups decoded tokens into words according to language-specific conventions, such as scripts that do not delimit words with spaces.

The Encoding class represents the vocabulary and tokenization rules for a particular language or model. It is assembled by get_encoding() in …/tokenizer.py from a vocabulary file of byte-pair-encoding merge ranks, together with mappings for special tokens.

Encoding handles the low-level work of encoding and decoding text, mapping strings to integer IDs and vice versa using the constructed mappings. The get_encoding() factory function retrieves a cached Encoding instance based on parameters such as whether the vocabulary is English-only or multilingual.

Vocab

References: whisper/tokenizer.py

The Encoding class represents the vocabulary and tokenization rules for a particular model or language. It manages a mapping from tokens to integer IDs that is used during encoding and decoding of text.

The mapping is built by get_encoding() in …/tokenizer.py from a vocabulary file of byte-pair-encoding merge ranks. The rank of each entry determines the order in which byte pairs are merged during encoding, so frequent character sequences collapse into single tokens.

Special tokens such as <|endoftext|>, <|startoftranscript|>, language and task markers, and timestamp tokens are appended after the regular vocabulary with dedicated integer IDs. These mappings are exposed for efficient lookup during encoding and decoding.

The encode() and decode() methods handle the core functionality of converting between text and integer IDs. encode() applies the BPE merge rules to the input text and returns the resulting IDs; decode() reverses this process.

Cached properties on the Tokenizer hold commonly used IDs so they do not have to be recomputed, and the underlying mappings remain accessible for external use cases like preprocessing.

Encoding

References: whisper/tokenizer.py

The Encoding class represents the vocabulary and tokenization rules for a particular language or model. It handles encoding text into integer token IDs and decoding IDs back into text.

It is built from two important mappings: the byte-pair-encoding merge ranks loaded from the vocabulary file, and a special-token mapping for markers such as the start and end of a transcript.

The vocabulary file lists every token together with its merge rank and integer ID. The special-token mappings are built from constants defined in …/tokenizer.py, covering language, task, and timestamp tokens.

Two main methods, encode() and decode(), handle the core encoding and decoding functionality. encode() takes raw text and applies the BPE merge rules to convert it into integer IDs; because the vocabulary is byte-level, any input text can be encoded without out-of-vocabulary failures.

decode() reverses this process, taking a list of integer IDs and converting them back into the original text by looking up each ID in the vocabulary mapping.

The merge ranks determine which byte pairs are combined first during encoding, so frequent character sequences collapse into single tokens and common words are represented compactly.

The Encoding provides all the data structures and methods needed for the Tokenizer class to efficiently encode and decode text. See the Tokenization section for more details on how the Tokenizer interfaces with Encoding.
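
As a rough illustration of how such an Encoding is assembled, the abridged sketch below builds a tiktoken Encoding from a ranks file and a handful of special tokens. The real get_encoding() in …/tokenizer.py reads a packaged .tiktoken file, uses a more elaborate pattern string, and registers many more special tokens, including language and timestamp markers:

```python
import base64
import os

import tiktoken


def build_encoding(vocab_path: str) -> tiktoken.Encoding:
    # BPE merge ranks: each line of the file is "<base64 token bytes> <rank>"
    ranks = {
        base64.b64decode(token): int(rank)
        for token, rank in (line.split() for line in open(vocab_path) if line)
    }
    n_vocab = len(ranks)

    # special tokens are appended after the regular vocabulary
    specials = ["<|endoftext|>", "<|startoftranscript|>", "<|translate|>",
                "<|transcribe|>", "<|nospeech|>", "<|notimestamps|>"]
    special_tokens = {tok: n_vocab + i for i, tok in enumerate(specials)}

    return tiktoken.Encoding(
        name=os.path.basename(vocab_path),
        explicit_n_vocab=n_vocab + len(special_tokens),
        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+""",
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )
```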

Decoding

References: whisper/tokenizer.py

The Tokenizer class handles decoding integer token IDs back into text through its decode() method. This method takes a list of token IDs and uses the underlying Encoding instance's mappings to look up the corresponding token strings, joining them to reconstruct the original text (spacing is carried by the tokens themselves).

The Encoding class contains the core mappings from integers to tokens. Its constructor builds mappings from token IDs to strings using the vocabulary file. The decode() method on Encoding simply looks up the token string for each ID in its mappings and concatenates them.

The decode() method must handle special tokens differently from regular text. By default, timestamp and other task-specific tokens are stripped from the output, while decode_with_timestamps() renders timestamp tokens explicitly as <|x.xx|> markers. Proper handling of these special tokens is important for reconstructing clean text from model predictions.

The Tokenizer caches commonly used values, such as the end-of-text ID and the start of the timestamp range, for efficient lookup during decoding. The choice between decode() and decode_with_timestamps() controls whether timestamps appear in the output, allowing flexibility depending on the task.
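
A hedged example of the difference between plain and timestamp-aware decoding follows; the token offsets and rendered strings in the comments are illustrative:

```python
from whisper.tokenizer import get_tokenizer

tokenizer = get_tokenizer(multilingual=True, language="en", task="transcribe")

ids = [tokenizer.timestamp_begin]          # <|0.00|>
ids += tokenizer.encode(" Ask not what your country can do for you")
ids += [tokenizer.timestamp_begin + 150]   # timestamps advance in 0.02 s steps -> <|3.00|>

print(tokenizer.decode(ids))                  # timestamp tokens are stripped
print(tokenizer.decode_with_timestamps(ids))  # "<|0.00|> Ask not ... <|3.00|>"
```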

Evaluation Datasets

References: data

The data directory contains datasets that were used to evaluate speech recognition models during development and testing. It includes several English-only and multilingual datasets covering both short-form and long-form speech.

The …/README.md file documents the source and preprocessing details of each dataset. Some of the datasets included are:

  • LibriSpeech: A large English audio book corpus for speech recognition.
  • TED-LIUM: Transcripts and audio from TED talks. Used for long-form speech recognition evaluation.
  • Common Voice: An open source speech dataset collected via crowdsourcing.
  • WSJ: The Wall Street Journal corpus for benchmarking speech and language models.
  • Multilingual LibriSpeech: An expanded version of LibriSpeech with additional languages.

The …/README.md file also describes scripts used to preprocess datasets, such as eval2000_data_prep.sh for CallHome and Switchboard corpora. It provides timestamps used to slice audio segments from TED-LIUM talks and lists source IDs for samples used from the Kincaid46 dataset.

Testing

References: tests

The whisper library provides a comprehensive suite of unit tests to validate core functionality without exposing implementation details. The tests directory contains test files that utilize PyTest and parameterization to thoroughly exercise key algorithms, classes, and functions. This includes testing for audio processing, text normalization, timing operations, tokenization, and end-to-end speech transcription.

The …/conftest.py file handles shared test configuration and fixtures. It defines the random fixture which seeds the random number generators used across tests, ensuring reproducibility. This is a common pattern to generate random data for tests in a consistent way.

The …/test_audio.py file tests audio loading and spectrogram generation functionality. It loads a sample audio file, validates properties like the sample rate and dimensions, and compares the output of log_mel_spectrogram() when called directly versus on a file path. This verifies correct functionality without exposing implementation details of the underlying audio processing algorithms.

The EnglishNumberNormalizer class in …/test_normalizer.py is tested extensively. It contains mappings and parsing rules that convert written number expressions like "twenty five" into standardized numeric strings like "25". The thorough testing validates these rules against a wide range of inputs.

The median_filter() function tested in …/test_timing.py performs median filtering over sliding windows of the input. The test pads the input the same way SciPy does so results near the edges agree, and comparing the output to SciPy's implementation validates that the behavior matches a well-known reference library.

The transcribe() method tested in …/test_transcribe.py uses the trained whisper model to transcribe audio into text. Tests validate the output transcription text and timing information meet expectations, ensuring the core speech recognition capability is implemented correctly.

Unit Tests

References: tests/conftest.py, tests/test_audio.py, tests/test_normalizer.py, tests/test_timing.py, tests/test_tokenizer.py, tests/test_transcribe.py

The tests directory contains comprehensive unit tests for core functionality using Pytest. Thorough testing is provided through test cases across multiple files that validate different aspects of the library.

The …/conftest.py file defines useful Pytest fixtures like random() for seeding random numbers. This ensures tests are reproducible.

Test cases are implemented in files mirroring the code structure. The …/test_audio.py file contains tests for the whisper.audio module. It loads audio, checks properties, and validates the output of functions like log_mel_spectrogram() matches expectations.

The …/test_normalizer.py file tests the text normalization classes like EnglishNumberNormalizer. It utilizes Pytest parameterization to pass various inputs and check the outputs meet specifications. This validates the normalization algorithms are implemented correctly.

Dynamic time warping, median filtering and other timing functions are tested in …/test_timing.py. Equivalence of CPU and GPU implementations is checked along with correctness on different input shapes and sizes.

Tokenization is exercised in …/test_tokenizer.py by encoding, decoding text, and checking token splitting behavior. This validates the expected mono-lingual and multi-lingual tokenizer functionality.

Model transcription is end-to-end tested in …/test_transcribe.py. It loads a model, transcribes audio, and asserts properties of the result, such as the presence of expected phrases, sensible timing information, and a tokenized representation that matches the text.

Test Fixtures

References: tests/conftest.py

The random fixture defined in …/conftest.py seeds the random number generators used across tests. This ensures tests that use this fixture receive reproducible random numbers.

The file imports the pytest, random, and numpy modules. It defines one fixture, random(), which calls random.seed() and numpy.random.seed() to seed both the Python random module and NumPy's random generator with the value 42.

By seeding the random number generators with a fixed value, any tests using the random fixture are guaranteed to produce the same random numbers each time they are run. This allows tests to reliably generate random test data without changing the results. It is a common pattern when writing tests to seed the random number generators to avoid non-deterministic tests.

The pytest_configure() function adds a "requires_cuda" marker to pytest using config.addinivalue_line(). This marker can then be applied to tests that require a CUDA-enabled environment.
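
Based on that description, the fixture and marker registration look roughly like the sketch below; it is consistent with the behavior described above rather than a verbatim copy of the file:

```python
import random as rand

import numpy
import pytest


def pytest_configure(config):
    # register a custom marker so tests can be tagged as CUDA-only
    config.addinivalue_line("markers", "requires_cuda")


@pytest.fixture
def random():
    # seed both Python's and NumPy's RNGs so test data is reproducible
    rand.seed(42)
    numpy.random.seed(42)
```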

Audio Processing

References: tests/test_audio.py

The …/test_audio.py file contains unit tests that validate core audio loading, padding, and spectrogram generation functionality. The file imports necessary modules like os, numpy, and audio functions from whisper.audio.

The main function is test_audio() which loads a sample audio file from the path "jfk.flac" using the load_audio() function. It asserts properties of the loaded audio like the sample rate and dimensions. The log_mel_spectrogram() function is called directly on the loaded audio and file path to generate mel spectrograms, and np.allclose() confirms the outputs are equal.

The load_audio() function reads audio files and returns a numpy array. log_mel_spectrogram() performs a short-time Fourier transform, maps frequencies to the mel scale, and takes the log to convert to decibels, generating mel spectrogram features from the raw audio. These functions are imported and their implementations are not visible, but these tests validate the expected output and properties of the core audio processing functionality.
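
A condensed sketch of such a test is shown below, assuming the jfk.flac sample sits next to the test file; the duration bounds are illustrative:

```python
import os.path

import numpy as np

from whisper.audio import SAMPLE_RATE, load_audio, log_mel_spectrogram


def test_audio():
    audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
    audio = load_audio(audio_path)

    assert audio.ndim == 1                                       # mono waveform
    assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12  # roughly an 11 s clip

    # the spectrogram should be identical whether computed from the
    # in-memory array or directly from the file path
    mel_from_audio = log_mel_spectrogram(audio)
    mel_from_file = log_mel_spectrogram(audio_path)
    assert np.allclose(mel_from_audio, mel_from_file)
```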

Text Normalization

References: tests/test_normalizer.py

The tests in …/test_normalizer.py validate the implementations of various text normalization functions. The EnglishNumberNormalizer, EnglishSpellingNormalizer, and EnglishTextNormalizer classes handle number, spelling, and general text normalization respectively.

The EnglishNumberNormalizer class normalizes written English numbers to standardized numeric strings. It is tested extensively with examples covering whole numbers, fractions, currencies, percentages, dates, and ordinal values. The tests ensure the class properly parses these inputs and converts them using the correct algorithms.

The EnglishSpellingNormalizer class normalizes common spelling variations between British and American English. It contains mappings of alternative spellings to canonical spellings that are exercised by the tests.

The EnglishTextNormalizer class implements minor text cleanups like expanding contractions and standardizing units. The tests validate it correctly handles these types of normalization.

No implementation details of the classes are shown, but the thorough testing provides insight into their supported functionality and examples of inputs and outputs. This allows validating the business logic is properly implemented without exposing internal details.
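
For example, usage along these lines exercises the documented behavior; the import paths mirror those used by the tests, and the expected outputs are indicative:

```python
from whisper.normalizers import EnglishTextNormalizer
from whisper.normalizers.english import EnglishNumberNormalizer

numbers = EnglishNumberNormalizer()
print(numbers("twenty five"))     # expected: "25"
print(numbers("one point five"))  # expected: "1.5"

clean = EnglishTextNormalizer()
print(clean("I'm about 10 metres away."))  # lowercased, contractions expanded, spellings standardized
```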

Timing Functions

References: tests/test_timing.py

The …/test_timing.py file contains unit tests for timing-related functions in the …/timing.py module. It tests the dtw_cpu, dtw_cuda, and median_filter functions.

dtw_cpu implements dynamic time warping (DTW) on CPU using dynamic programming. dtw_cuda is a GPU implementation of DTW that is tested for equivalence against the CPU version. Tests ensure the functions produce the same output on different input tensors.

median_filter applies a median filter over sliding windows along the last axis. The tests pad the input so the result matches SciPy's implementation near the edges, and validate the output on tensors with varying shapes and filter widths. Equivalence of median_filter on CPU and GPU is also tested.

The tests are parameterized over input sizes and shapes using PyTest fixtures to cover a wide range of cases. NumPy, PyTorch, SciPy and CUDA are leveraged for flexible and accurate validation. This provides a comprehensive test suite for common timing operations, validating correctness and equivalence of implementations across devices and parameters.
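
A hedged sketch of the SciPy comparison is shown below; the input shape and filter widths are illustrative, and the reflect padding mirrors the edge handling the comparison relies on:

```python
import numpy as np
import scipy.ndimage
import torch

from whisper.timing import median_filter


def test_median_filter_matches_scipy():
    x = torch.randn(2, 8, 50)   # (batch, rows, time) - shape is illustrative
    for filter_width in (3, 5, 7):
        filtered = median_filter(x, filter_width)

        # reference: reflect-pad the last axis, run SciPy's median filter,
        # then crop the padding back off
        pad = filter_width // 2
        padded = np.pad(x.numpy(), [(0, 0), (0, 0), (pad, pad)], mode="reflect")
        expected = scipy.ndimage.median_filter(padded, size=(1, 1, filter_width))
        expected = expected[..., pad:-pad]

        assert np.allclose(filtered.numpy(), expected)
```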

Tokenization

References: tests/test_tokenizer.py

The …/test_tokenizer.py file contains unit tests that exercise the encoding, decoding, and splitting functionality of the whisper tokenizer. These tests validate the core operations work as expected for both mono-lingual and multi-lingual use cases.

The test_tokenizer() function checks that a tokenizer retrieved with get_tokenizer() has the expected start of text token defined, and that language codes and tokens are aligned properly.

Encoding and decoding of text is tested with test_multilingual_tokenizer(). It encodes Korean text with both the English-only and multilingual tokenizers and verifies that decoding the IDs reproduces the original text in each case. The test also checks that the multilingual tokenizer produces fewer tokens for the same Korean input, since its vocabulary covers non-English scripts more efficiently.

Token splitting is exercised by test_split_on_unicode(). It passes a list of tokens to the split_tokens_on_unicode() method, which groups tokens so that each group decodes to a valid Unicode string. The method's return values are validated against expectations.

In summary, these tests cover the key responsibilities of the whisper tokenizer - encoding text into integer IDs, decoding IDs back into text, and properly splitting multi-lingual tokens. They validate the tokenizer provides the expected functionality without relying on implementation details.
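
A condensed sketch of the round-trip and token-count checks looks like this; the Korean sample sentence is illustrative:

```python
from whisper.tokenizer import get_tokenizer


def test_multilingual_round_trip():
    gpt2_tokenizer = get_tokenizer(multilingual=False)
    multilingual_tokenizer = get_tokenizer(multilingual=True)

    text = "다람쥐 헌 쳇바퀴에 타고파"   # Korean sample sentence
    gpt2_ids = gpt2_tokenizer.encode(text)
    multi_ids = multilingual_tokenizer.encode(text)

    # both tokenizers must reproduce the original text exactly
    assert gpt2_tokenizer.decode(gpt2_ids) == text
    assert multilingual_tokenizer.decode(multi_ids) == text

    # the multilingual vocabulary represents Korean more compactly
    assert len(multi_ids) < len(gpt2_ids)
```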

Transcription

References: tests/test_transcribe.py

The …/test_transcribe.py file contains unit tests that validate end-to-end audio transcription functionality. The main test_transcribe function loads a pre-trained Whisper model and uses it to transcribe an audio file. It then performs several checks on the transcription output to ensure it meets expectations.

The test_transcribe function runs the transcription test for each available model by calling whisper.available_models(). It loads the model onto the correct device using model.to(device) and calls the model's transcribe method, passing the path to the "jfk.flac" audio file.

Some key checks performed on the transcription output include:

  • Comparing the predicted transcription text to expected phrases from the audio using string matching assertions.
  • Checking the transcription language matches what is expected for the audio.
  • Validating the full text by concatenating each segment matches the overall transcription text.
  • Comparing the tokenized representation generated by get_tokenizer to what the model produces.

Timing information returned from transcribe is also validated. Tests ensure start time is less than end time and check a particular word's timing falls within an expected range.

The main logic under test is the transcribe method, which handles loading a model, preprocessing audio, running inference, and returning the transcription result. These unit tests aim to validate this end-to-end transcription functionality works as expected on real audio data.
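
A stripped-down sketch of such an end-to-end check is shown below. The model name, audio path, and expected phrase are examples; the real test parameterizes over whisper.available_models() and checks word-level timings in more detail:

```python
import whisper


def test_transcribe_tiny_en():
    model = whisper.load_model("tiny.en")
    audio_path = "tests/jfk.flac"   # assumed path relative to the repository root

    result = model.transcribe(audio_path, temperature=0.0, word_timestamps=True)

    assert result["language"] == "en"
    assert "ask not what your country can do for you" in result["text"].lower()

    # every segment should span a sensible time range
    for segment in result["segments"]:
        assert segment["start"] < segment["end"]
```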