ema_residual_vector_quantizer
Residual Vector Quantizer with Exponential Moving Average codebook updates.
Replaces the gradient-based codebook loss with EMA updates to codebook embeddings, following van den Oord et al. 2017 (VQ-VAE). Only the commitment loss is back-propagated; codebook vectors are updated via running averages of assigned encoder outputs.
Classes
EmaResidualVectorQuantizer
EmaResidualVectorQuantizer(num_levels: int, num_embeddings: int | Sequence[int], embedding_dim: int, beta: float = 0.25, ema_decay: float = 0.99, epsilon: float = 1e-05, **kwargs)
Residual VQ with EMA codebook updates.
Instead of learning codebook embeddings via gradient descent (which requires a codebook loss term), this layer maintains exponential moving averages of cluster assignment counts and embedding sums. Codebook vectors are derived from these running statistics with Laplace smoothing for numerical stability.
Only the commitment loss is back-propagated through the encoder; the straight-through estimator copies gradients from the decoder to the encoder as in the standard VQ-VAE.
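The EMA update described above can be illustrated with a minimal standalone NumPy sketch for a single VQ level. All names here are illustrative, not the layer's actual internals; the straight-through gradient copy is a framework-level operation and is omitted:

```python
import numpy as np

def ema_update(codebook, counts_ema, sums_ema, z, decay=0.99, eps=1e-5):
    """One EMA step: assign z to nearest codes, update running stats."""
    # Nearest-code assignment by squared Euclidean distance.
    d = ((z[:, None, :] - codebook[None, :, :]) ** 2).sum(-1)  # [N, K]
    idx = d.argmin(axis=1)                                     # [N]
    one_hot = np.eye(codebook.shape[0])[idx]                   # [N, K]

    # Exponential moving averages of assignment counts and embedding sums.
    counts_ema = decay * counts_ema + (1 - decay) * one_hot.sum(0)
    sums_ema = decay * sums_ema + (1 - decay) * (one_hot.T @ z)

    # Laplace smoothing of the counts keeps rarely used codes
    # numerically stable when dividing.
    n = counts_ema.sum()
    smoothed = (counts_ema + eps) / (n + codebook.shape[0] * eps) * n
    codebook = sums_ema / smoothed[:, None]
    return codebook, counts_ema, sums_ema, idx
```

Because the codebook is derived purely from these running statistics, no gradient ever flows into the embeddings themselves.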
Input: [..., D] (last dim = embedding_dim)
Output: [..., D] (sum of per-level dequantized vectors)
Parameters:

- num_levels (int) – Number of residual VQ stages (M >= 1).
- num_embeddings (int | Sequence[int]) – Codebook size K per level (int or per-level list).
- embedding_dim (int) – Latent dimensionality D.
- beta (float, default: 0.25) – Commitment loss coefficient.
- ema_decay (float, default: 0.99) – EMA decay rate for codebook updates (0.99–0.999 typical).
- epsilon (float, default: 1e-05) – Small constant for Laplace smoothing of cluster counts.
Metrics (logged via the metrics property):

- Per level: rvq_l{l}_perplexity, rvq_l{l}_usage, rvq_l{l}_bits_per_index
- Aggregates: rvq_perplexity_mean, rvq_usage_mean, rvq_bits_per_index_sum (entropy lower bound)
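These statistics can be computed from a batch of code indices. A hedged NumPy sketch, assuming the usual definitions (perplexity as the exponential of the assignment entropy, bits as that entropy in base 2) rather than the layer's exact source:

```python
import numpy as np

def codebook_stats(indices, num_embeddings):
    """Illustrative per-level codebook statistics from flat indices."""
    probs = np.bincount(indices, minlength=num_embeddings) / len(indices)
    nz = probs[probs > 0]
    entropy_nats = -(nz * np.log(nz)).sum()
    return {
        "perplexity": float(np.exp(entropy_nats)),          # effective number of codes
        "usage": float((probs > 0).mean()),                 # fraction of codes hit
        "bits_per_index": float(entropy_nats / np.log(2)),  # entropy in bits
    }
```

Summing bits_per_index across levels gives a lower bound on the entropy of the full code, which is why the aggregate is a sum rather than a mean.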
Losses added per level:

beta * ||stop(q_l) - r_l||^2 (commitment only; no codebook gradient loss)
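The residual pass and the per-level commitment loss can be sketched as follows (a minimal NumPy version for the arithmetic only; stop-gradient is a no-op here and the names are illustrative):

```python
import numpy as np

def residual_quantize(z, codebooks, beta=0.25):
    """Quantize z through a list of per-level codebooks [K_l, D]."""
    residual = z
    y = np.zeros_like(z)
    commit_loss = 0.0
    for cb in codebooks:  # one codebook per residual level
        d = ((residual[:, None, :] - cb[None, :, :]) ** 2).sum(-1)
        q = cb[d.argmin(axis=1)]          # nearest code vectors q_l
        commit_loss += beta * ((q - residual) ** 2).mean()
        y += q                            # accumulate dequantized output
        residual = residual - q           # next level quantizes the residual
    return y, commit_loss
```

Each level only sees what the previous levels failed to explain, so the summed output y approximates z increasingly well as levels are added.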
Example:

```python
rvq = EmaResidualVectorQuantizer(
    num_levels=4,
    num_embeddings=64,
    embedding_dim=16,
    ema_decay=0.99,
)
y = rvq(z, training=True)                 # forward + EMA update
y, indices = rvq(z, return_indices=True)  # also return codes
```
References
- van den Oord, A., Vinyals, O. & Kavukcuoglu, K. (2017). Neural Discrete Representation Learning. NeurIPS.
Source code in helia_edge/layers/ema_residual_vector_quantizer.py
Functions
call
call(x: keras.KerasTensor, training: bool = False, return_indices: bool = False) -> keras.KerasTensor | tuple[keras.KerasTensor, list[keras.KerasTensor]]
Quantize x through all residual levels.
Parameters:

- x (KerasTensor) – [..., D] latent to be quantized.
- training (bool, default: False) – If True, run EMA codebook updates.
- return_indices (bool, default: False) – If True, also return per-level flat indices.

Returns:

- KerasTensor | tuple[KerasTensor, list[KerasTensor]] – y or (y, indices_list): the dequantized vector and, optionally, the per-level index tensors.
encode
Return list of per-level flat index tensors [N] (no gradients).
decode
Sum per-level code vectors from indices_list and reshape.
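The decode step can be sketched as a gather-and-sum over the per-level codebooks (a hypothetical standalone NumPy version; the actual layer also restores the original batch shape):

```python
import numpy as np

def decode(indices_list, codebooks):
    """Sum each level's code vectors, gathered by flat index."""
    return sum(cb[idx] for idx, cb in zip(indices_list, codebooks))
```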