ema_residual_vector_quantizer

Residual Vector Quantizer with Exponential Moving Average codebook updates.

Replaces the gradient-based codebook loss with EMA updates to codebook embeddings, following van den Oord et al. 2017 (VQ-VAE). Only the commitment loss is back-propagated; codebook vectors are updated via running averages of assigned encoder outputs.

Classes

EmaResidualVectorQuantizer

EmaResidualVectorQuantizer(num_levels: int, num_embeddings: int | Sequence[int], embedding_dim: int, beta: float = 0.25, ema_decay: float = 0.99, epsilon: float = 1e-05, **kwargs)

Residual VQ with EMA codebook updates.

Instead of learning codebook embeddings via gradient descent (which requires a codebook loss term), this layer maintains exponential moving averages of cluster assignment counts and embedding sums. Codebook vectors are derived from these running statistics with Laplace smoothing for numerical stability.

Only the commitment loss is back-propagated through the encoder; the straight-through estimator copies gradients from the decoder to the encoder as in the standard VQ-VAE.
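For reference, the EMA rule (Appendix A.1 of the VQ-VAE paper) can be written out in a few lines. The following is an illustrative NumPy sketch of one update step for a single level, not the layer's implementation; all names are assumptions:

import numpy as np

def ema_codebook_update(codebook, counts, weights, one_hot, residual,
                        decay=0.99, epsilon=1e-5):
    # Running statistics: assignment counts [K] and sums of assigned
    # residuals [K, D], each decayed toward the current mini-batch.
    counts = decay * counts + (1.0 - decay) * one_hot.sum(axis=0)
    weights = decay * weights + (1.0 - decay) * one_hot.T @ residual
    # Laplace smoothing keeps rarely used codes from collapsing to zero.
    n = counts.sum()
    K = counts.shape[0]
    smoothed = (counts + epsilon) / (n + K * epsilon) * n
    # Each codebook vector is the smoothed running mean of its cluster.
    codebook = weights / smoothed[:, None]
    return codebook, counts, weights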

Input: [..., D] (last dim = embedding_dim)
Output: [..., D] (sum of per-level dequantized vectors)

Parameters:

  • num_levels

    (int) –

    Number of residual VQ stages (M >= 1).

  • num_embeddings

    (int | Sequence[int]) –

    Codebook size K per level (int or per-level list).

  • embedding_dim

    (int) –

    Latent dimensionality D.

  • beta

    (float, default: 0.25 ) –

    Commitment loss coefficient.

  • ema_decay

    (float, default: 0.99 ) –

EMA decay rate for codebook updates (0.99 to 0.999 typical).

  • epsilon

    (float, default: 1e-05 ) –

    Small constant for Laplace smoothing of cluster counts.

Metrics (logged via the metrics property):

  • rvq_l{l}_perplexity, rvq_l{l}_usage, rvq_l{l}_bits_per_index (per level)

  • rvq_perplexity_mean, rvq_usage_mean, rvq_bits_per_index_sum (entropy lower bound)
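Perplexity is 2^H, where H is the entropy (in bits) of the empirical code distribution over a batch, so it reads as the effective number of codes in use. A quick standalone check of the arithmetic (not part of the layer):

import numpy as np

probs = np.array([0.5, 0.25, 0.25, 0.0])  # empirical usage of K = 4 codes
p = probs[probs > 0]
H = -(p * np.log2(p)).sum()               # 1.5 bits per index
print(2.0 ** H)                           # perplexity ~ 2.83 effective codes
print((probs > 0).mean())                 # usage = 0.75 (3 of 4 codes hit)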

Losses added per level
  • beta * ||stop(q_l) - r_l||^2 (commitment only; no codebook gradient loss)

Example:

rvq = EmaResidualVectorQuantizer(
    num_levels=4,
    num_embeddings=64,
    embedding_dim=16,
    ema_decay=0.99,
)
y = rvq(z, training=True)          # forward + EMA update
y, indices = rvq(z, return_indices=True)  # also return codes
References
  • van den Oord, A., Vinyals, O. & Kavukcuoglu, K. (2017). Neural Discrete Representation Learning. NeurIPS.
Source code in helia_edge/layers/ema_residual_vector_quantizer.py
def __init__(
    self,
    num_levels: int,
    num_embeddings: int | Sequence[int],
    embedding_dim: int,
    beta: float = 0.25,
    ema_decay: float = 0.99,
    epsilon: float = 1e-5,
    **kwargs,
):
    super().__init__(**kwargs)
    if num_levels < 1 or embedding_dim <= 0 or beta <= 0:
        raise ValueError("num_levels>=1, embedding_dim>0, beta>0 required.")
    self.M = int(num_levels)
    self.D = int(embedding_dim)
    if isinstance(num_embeddings, (list, tuple)):
        if len(num_embeddings) != self.M:
            raise ValueError("num_embeddings list must have length = num_levels.")
        self.Ks = [int(k) for k in num_embeddings]
    else:
        self.Ks = [int(num_embeddings)] * self.M
    self.beta = float(beta)
    self.ema_decay = float(ema_decay)
    self.epsilon = float(epsilon)

    # Per-level metric trackers
    self._lvl_perp = [
        keras.metrics.Mean(name=f"rvq_l{lvl + 1}_perplexity")
        for lvl in range(self.M)
    ]
    self._lvl_usage = [
        keras.metrics.Mean(name=f"rvq_l{lvl + 1}_usage")
        for lvl in range(self.M)
    ]
    self._lvl_bpi = [
        keras.metrics.Mean(name=f"rvq_l{lvl + 1}_bits_per_index")
        for lvl in range(self.M)
    ]
    # Aggregates
    self._perp_mean = keras.metrics.Mean(name="rvq_perplexity_mean")
    self._usage_mean = keras.metrics.Mean(name="rvq_usage_mean")
    self._bpi_sum = keras.metrics.Mean(name="rvq_bits_per_index_sum")

    self._codebooks: list = []
    self._ema_counts: list = []
    self._ema_weights: list = []
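The codebooks and EMA statistics are created as variables later; the lists above start empty. A plausible build sketch follows, with initializer choices and variable names that are assumptions rather than the shipped code:

def build(self, input_shape):
    for lvl, K in enumerate(self.Ks):
        # Codebook embeddings [K, D]; non-trainable because EMA, not the
        # optimizer, updates them.
        self._codebooks.append(self.add_weight(
            name=f"codebook_{lvl}", shape=(K, self.D),
            initializer="random_uniform", trainable=False,
        ))
        # Running assignment counts [K] and assigned-vector sums [K, D].
        self._ema_counts.append(self.add_weight(
            name=f"ema_counts_{lvl}", shape=(K,),
            initializer="zeros", trainable=False,
        ))
        self._ema_weights.append(self.add_weight(
            name=f"ema_weights_{lvl}", shape=(K, self.D),
            initializer="zeros", trainable=False,
        ))
    super().build(input_shape)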

Attributes

metrics property
metrics

Expose per-level + aggregate metrics so Model.fit logs them.

Functions

call
call(x: keras.KerasTensor, training: bool = False, return_indices: bool = False) -> keras.KerasTensor | tuple[keras.KerasTensor, list[keras.KerasTensor]]

Quantize x through all residual levels.

Parameters:

  • x
    (KerasTensor) –

    [..., D] latent to be quantized.

  • training
    (bool, default: False ) –

    If True, run EMA codebook updates.

  • return_indices
    (bool, default: False ) –

    If True, also return per-level flat indices.

Returns:

  • KerasTensor | tuple[KerasTensor, list[KerasTensor]] –

    y or (y, indices_list): the dequantized vector and, optionally, the per-level index tensors.

Source code in helia_edge/layers/ema_residual_vector_quantizer.py
def call(
    self,
    x: keras.KerasTensor,
    training: bool = False,
    return_indices: bool = False,
) -> keras.KerasTensor | tuple[keras.KerasTensor, list[keras.KerasTensor]]:
    """Quantize *x* through all residual levels.

    Args:
        x: ``[..., D]`` latent to be quantized.
        training: If ``True``, run EMA codebook updates.
        return_indices: If ``True``, also return per-level flat indices.

    Returns:
        ``y`` or ``(y, indices_list)``: dequantized vector and optional
        per-level index tensors.
    """
    x = keras.ops.convert_to_tensor(x, dtype=self.compute_dtype)
    shape = keras.ops.shape(x)
    flat = keras.ops.reshape(x, (-1, self.D))

    residual = flat
    q_sum = keras.ops.zeros_like(flat)
    indices_list = []
    perp_vals, usage_vals, bpi_vals = [], [], []

    for lvl, (K, codebook) in enumerate(zip(self.Ks, self._codebooks)):
        idx, q_l = self._nearest(residual, codebook)
        indices_list.append(idx)
        q_sum = q_sum + q_l

        # Commitment loss only (EMA handles codebook)
        ql_st = keras.ops.stop_gradient(q_l)
        commitment = keras.ops.mean(keras.ops.square(ql_st - residual))
        self.add_loss(self.beta * commitment)

        # EMA codebook update during training
        if training:
            self._ema_update(lvl, idx, residual, K)

        residual = residual - ql_st

        # Per-level metrics
        one_hot = keras.ops.one_hot(idx, K)
        probs = keras.ops.mean(one_hot, axis=0)
        eps = keras.ops.convert_to_tensor(1e-10, dtype=self.compute_dtype)
        log2 = keras.ops.log(
            keras.ops.convert_to_tensor(2.0, self.compute_dtype)
        )
        H = -keras.ops.sum(probs * (keras.ops.log(probs + eps) / log2))
        perp = keras.ops.exp(H * log2)
        usage = keras.ops.sum(
            keras.ops.cast(probs > 0, self.compute_dtype)
        ) / float(K)

        self._lvl_perp[lvl].update_state(perp)
        self._lvl_usage[lvl].update_state(usage)
        self._lvl_bpi[lvl].update_state(H)
        perp_vals.append(perp)
        usage_vals.append(usage)
        bpi_vals.append(H)

    # Aggregate metrics
    self._perp_mean.update_state(sum(perp_vals) / float(self.M))
    self._usage_mean.update_state(sum(usage_vals) / float(self.M))
    self._bpi_sum.update_state(sum(bpi_vals))

    # Straight-through estimator
    y_flat = flat + keras.ops.stop_gradient(q_sum - flat)
    y = keras.ops.reshape(y_flat, shape)
    return (y, indices_list) if return_indices else y
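call relies on two private helpers not reproduced on this page. _ema_update presumably applies the running-statistics rule sketched earlier through variable.assign; _nearest might look like the following squared-L2 lookup (an assumed implementation, not the shipped code):

def _nearest(self, flat, codebook):
    # Pairwise squared distances between residuals [N, D] and codes [K, D]:
    # ||r||^2 - 2 r.c + ||c||^2, expanded to avoid materializing differences.
    d = (
        keras.ops.sum(flat * flat, axis=1, keepdims=True)
        - 2.0 * keras.ops.matmul(flat, keras.ops.transpose(codebook))
        + keras.ops.expand_dims(keras.ops.sum(codebook * codebook, axis=1), 0)
    )
    idx = keras.ops.argmin(d, axis=1)          # [N] flat indices
    q = keras.ops.take(codebook, idx, axis=0)  # [N, D] selected code vectors
    return idx, q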
encode
encode(x: keras.KerasTensor) -> list[keras.KerasTensor]

Return list of per-level flat index tensors [N] (no gradients).

Source code in helia_edge/layers/ema_residual_vector_quantizer.py
def encode(self, x: keras.KerasTensor) -> list[keras.KerasTensor]:
    """Return list of per-level flat index tensors ``[N]`` (no gradients)."""
    x = keras.ops.convert_to_tensor(x, dtype=self.compute_dtype)
    flat = keras.ops.reshape(x, (-1, self.D))
    residual = flat
    indices = []
    for codebook in self._codebooks:
        idx, q_l = self._nearest(residual, codebook)
        indices.append(idx)
        residual = residual - q_l
    return indices
decode
decode(indices_list: list[keras.KerasTensor], original_shape: tuple[int, ...]) -> keras.KerasTensor

Sum per-level code vectors from indices_list and reshape.

Source code in helia_edge/layers/ema_residual_vector_quantizer.py
def decode(
    self,
    indices_list: list[keras.KerasTensor],
    original_shape: tuple[int, ...],
) -> keras.KerasTensor:
    """Sum per-level code vectors from *indices_list* and reshape."""
    q_sum = None
    for idx, codebook in zip(indices_list, self._codebooks):
        q_l = keras.ops.take(codebook, idx, axis=0)
        q_sum = q_l if q_sum is None else (q_sum + q_l)
    return keras.ops.reshape(q_sum, original_shape)
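Together, encode and decode give a gradient-free round trip from latents to indices and back. A usage sketch (tensor names and shapes are assumptions):

codes = rvq.encode(z)               # M index tensors, each [N]
z_hat = rvq.decode(codes, z.shape)  # summed code vectors, reshaped like z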