Skip to content

residual_vector_quantizer

Classes

ResidualVectorQuantizer

ResidualVectorQuantizer(num_levels, num_embeddings, embedding_dim, beta=0.25, **kwargs)

Residual Vector Quantizer (RVQ) with straight-through estimator.

Input: [..., D] (last dim = embedding_dim) Output: [..., D] (sum of per-level dequantized vectors; gradients pass through x)

Parameters:

  • num_levels

    int, number of residual VQ stages (M >= 1)

  • num_embeddings

    int OR sequence[int], codebook size K for each level

  • embedding_dim

    int, latent dimensionality D

  • beta

    float, commitment coefficient per level

Metrics (logged via metrics property): - rvq_l{l}_perplexity, rvq_l{l}_usage, rvq_l{l}_bits_per_index - rvq_perplexity_mean, rvq_usage_mean, rvq_bits_per_index_sum (entropy lower bound)

Losses added per level
  • beta * ||stop(q_l) - r_l||^2 + ||q_l - stop(r_l)||^2, where r_l is the current residual and q_l the level-l code vector.
Source code in helia_edge/layers/residual_vector_quantizer.py
def __init__(self, num_levels, num_embeddings, embedding_dim, beta=0.25, **kwargs):
    super().__init__(**kwargs)
    if num_levels < 1 or embedding_dim <= 0 or beta <= 0:
        raise ValueError("num_levels>=1, embedding_dim>0, beta>0 required.")
    self.M = int(num_levels)
    self.D = int(embedding_dim)
    # Allow int or per-level list/tuple for K
    if isinstance(num_embeddings, (list, tuple)):
        if len(num_embeddings) != self.M:
            raise ValueError("num_embeddings list must have length = num_levels.")
        self.Ks = [int(k) for k in num_embeddings]
    else:
        self.Ks = [int(num_embeddings)] * self.M
    self.beta = float(beta)

    # Per-level metric trackers
    self._lvl_perp = [keras.metrics.Mean(name=f"rvq_l{lvl + 1}_perplexity") for lvl in range(self.M)]
    self._lvl_usage = [keras.metrics.Mean(name=f"rvq_l{lvl + 1}_usage") for lvl in range(self.M)]
    self._lvl_bpi = [keras.metrics.Mean(name=f"rvq_l{lvl + 1}_bits_per_index") for lvl in range(self.M)]
    # Aggregates
    self._perp_mean = keras.metrics.Mean(name="rvq_perplexity_mean")
    self._usage_mean = keras.metrics.Mean(name="rvq_usage_mean")
    self._bpi_sum = keras.metrics.Mean(name="rvq_bits_per_index_sum")

    self._codebooks = []  # created in build()

Functions

call
call(x, return_indices: bool = False)

Parameters:

  • x

    [..., D] latent to be quantized.

  • return_indices
    (bool, default: False ) –

    if True, also returns list of flat indices (one tensor per level).

Returns:

  • y or (y, indices_list): dequantized vector and optional per-level indices.

Source code in helia_edge/layers/residual_vector_quantizer.py
def call(self, x, return_indices: bool = False):
    """
    Args:
      x: [..., D] latent to be quantized.
      return_indices: if True, also returns list of flat indices (one tensor per level).

    Returns:
      y or (y, indices_list): dequantized vector and optional per-level indices.
    """
    x = keras.ops.convert_to_tensor(x, dtype=self.compute_dtype)
    shape = keras.ops.shape(x)
    flat = keras.ops.reshape(x, (-1, self.D))  # [N, D]

    residual = flat  # r_1 = x
    q_sum = keras.ops.zeros_like(flat)  # accumulate ∑ q_l
    indices_list = []
    perp_vals, usage_vals, bpi_vals = [], [], []

    for lvl, (K, codebook) in enumerate(zip(self.Ks, self._codebooks)):
        idx, q_l = self._nearest(residual, codebook)  # [N], [N,D]
        indices_list.append(idx)
        q_sum = q_sum + q_l  # accumulate
        # Losses for this level on its residual
        ql_st = keras.ops.stop_gradient(q_l)
        res_st = keras.ops.stop_gradient(residual)
        commitment = keras.ops.mean(keras.ops.square(ql_st - residual))
        codebook_loss = keras.ops.mean(keras.ops.square(q_l - res_st))
        self.add_loss(self.beta * commitment + codebook_loss)

        # Update residual for next stage (quantize the residual)
        residual = residual - ql_st

        # Metrics for this level
        one_hot = keras.ops.one_hot(idx, K)  # [N,K]
        probs = keras.ops.mean(one_hot, axis=0)  # [K]
        eps = keras.ops.convert_to_tensor(1e-10, dtype=self.compute_dtype)
        log2 = keras.ops.log(keras.ops.convert_to_tensor(2.0, self.compute_dtype))
        H = -keras.ops.sum(probs * (keras.ops.log(probs + eps) / log2))  # bits/index
        # perp = keras.ops.pow(keras.ops.convert_to_tensor(2.0, self.compute_dtype), H)
        perp = keras.ops.exp(H * log2)
        usage = keras.ops.sum(keras.ops.cast(probs > 0, self.compute_dtype)) / float(K)

        self._lvl_perp[lvl].update_state(perp)
        self._lvl_usage[lvl].update_state(usage)
        self._lvl_bpi[lvl].update_state(H)
        perp_vals.append(perp)
        usage_vals.append(usage)
        bpi_vals.append(H)

    # Aggregate metrics across levels
    perp_mean = sum(perp_vals) / float(self.M)
    usage_mean = sum(usage_vals) / float(self.M)
    bpi_sum = sum(bpi_vals)  # total entropy lower bound
    self._perp_mean.update_state(perp_mean)
    self._usage_mean.update_state(usage_mean)
    self._bpi_sum.update_state(bpi_sum)

    # Straight-through estimator for the whole stack: forward=q_sum, backward=identity
    y_flat = flat + keras.ops.stop_gradient(q_sum - flat)
    y = keras.ops.reshape(y_flat, shape)

    return (y, indices_list) if return_indices else y
encode
encode(x)

Return list of per-level flat index tensors [N] (no gradients).

Source code in helia_edge/layers/residual_vector_quantizer.py
def encode(self, x):
    """Return list of per-level flat index tensors [N] (no gradients)."""
    x = keras.ops.convert_to_tensor(x, dtype=self.compute_dtype)
    flat = keras.ops.reshape(x, (-1, self.D))
    residual = flat
    indices = []
    for K, codebook in zip(self.Ks, self._codebooks):
        idx, q_l = self._nearest(residual, codebook)
        indices.append(idx)
        residual = residual - q_l
    return indices
decode
decode(indices_list, original_shape)

Sum per-level code vectors from indices_list and reshape to original_shape.

Source code in helia_edge/layers/residual_vector_quantizer.py
def decode(self, indices_list, original_shape):
    """Sum per-level code vectors from indices_list and reshape to original_shape."""
    # indices_list: list of 1D int tensors [N]
    q_sum = None
    for idx, codebook in zip(indices_list, self._codebooks):
        q_l = keras.ops.take(codebook, idx, axis=0)
        q_sum = q_l if q_sum is None else (q_sum + q_l)
    return keras.ops.reshape(q_sum, original_shape)