
gumbel_softmax_bottleneck

Classes

GumbelSoftmaxBottleneck

GumbelSoftmaxBottleneck(num_embeddings: int, embedding_dim: int, temperature: float = 1.0, hard: bool = True, input_is_logits: bool = False, use_bias: bool = True, kl_weight: float = 1.0, **kwargs)

Discrete bottleneck via Gumbel-Softmax (Concrete) with optional straight-through hard one-hot.

Inputs

- x: [..., Din] (features) -- if input_is_logits=False (default), we learn a linear projection to K logits
- x: [..., K] (logits) -- if input_is_logits=True, we treat the last dim as K logits directly

Outputs

z: [..., D] expected embedding z = soft_one_hot @ embed (D = embedding_dim)
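A minimal usage sketch (shapes and call semantics follow the Inputs/Outputs above; the import path mirrors the source location shown further below and is an assumption if the class is re-exported elsewhere):

```python
import numpy as np

from helia_edge.layers.gumbel_softmax_bottleneck import GumbelSoftmaxBottleneck

# Feature input (default): the layer learns a [Din, K] projection to logits.
bottleneck = GumbelSoftmaxBottleneck(num_embeddings=64, embedding_dim=32)
x = np.random.randn(8, 10, 128).astype("float32")   # [batch, tokens, Din]
z = bottleneck(x)                                    # -> [8, 10, 32] == [..., D]

# Logit input: pass K logits directly; no projection is learned.
logit_bottleneck = GumbelSoftmaxBottleneck(
    num_embeddings=64, embedding_dim=32, input_is_logits=True
)
logits = np.random.randn(8, 10, 64).astype("float32")  # [batch, tokens, K]
z2 = logit_bottleneck(logits)                           # -> [8, 10, 32]
```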

Adds loss

kl_weight * mean_bits_per_index (KL(q || Uniform(K)) in bits, averaged over tokens)
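For intuition, KL(q || Uniform(K)) in bits reduces to log2(K) minus the entropy of q. A small illustrative numpy sketch of that quantity (not the layer's internal implementation):

```python
import numpy as np

def kl_to_uniform_bits(q, eps=1e-8):
    """KL(q || Uniform(K)) in bits for code distributions q of shape [..., K].

    KL(q || U) = sum_k q_k * log2(q_k * K) = log2(K) - H(q); averaging this
    over tokens gives the mean_bits_per_index term described above.
    """
    K = q.shape[-1]
    q = np.clip(q, eps, 1.0)
    return np.sum(q * np.log2(q * K), axis=-1)

K = 8
print(kl_to_uniform_bits(np.full((K,), 1.0 / K)))  # uniform q -> ~0.0 bits
print(kl_to_uniform_bits(np.eye(K)[0]))            # one-hot q -> ~log2(8) = 3.0 bits
```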

Tracks metrics (logged via metrics):

- gs_bits_per_index (lower bound, bits/index)
- gs_perplexity (empirical perplexity from the hard argmax histogram)
- gs_usage (fraction of codes used at least once in the batch)
- gs_temperature (current τ; useful when annealing)
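For reference, the perplexity and usage metrics can be reproduced from the hard argmax indices of a batch; an illustrative numpy sketch (not the layer's actual code):

```python
import numpy as np

def codebook_stats(indices, K):
    """Perplexity and usage computed from hard argmax code indices over a batch."""
    counts = np.bincount(np.asarray(indices).reshape(-1), minlength=K).astype("float64")
    p = counts / counts.sum()                                # empirical code histogram
    entropy = -np.sum(np.where(p > 0, p * np.log(p), 0.0))   # in nats
    perplexity = np.exp(entropy)                             # in [1, K]; K means uniform usage
    usage = np.mean(counts > 0)                              # fraction of codes hit at least once
    return perplexity, usage

# e.g. indices = np.argmax(logits, axis=-1) for one batch
```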

Source code in helia_edge/layers/gumbel_softmax_bottleneck.py
def __init__(
    self,
    num_embeddings: int,
    embedding_dim: int,
    temperature: float = 1.0,
    hard: bool = True,
    input_is_logits: bool = False,
    use_bias: bool = True,
    kl_weight: float = 1.0,
    **kwargs,
):
    super().__init__(**kwargs)
    if num_embeddings <= 1 or embedding_dim <= 0:
        raise ValueError("num_embeddings must be >=2 and embedding_dim > 0.")
    self.K = int(num_embeddings)
    self.D = int(embedding_dim)

    self.hard = bool(hard)
    self.input_is_logits = bool(input_is_logits)
    self.use_bias = bool(use_bias)
    self.kl_weight = float(kl_weight)

    # temperature stored as non-trainable weight (easy to anneal via a callback)
    self.tau = self.add_weight(
        name="temperature",
        shape=(),
        initializer=keras.initializers.Constant(float(temperature)),
        trainable=False,
        dtype="float32",
    )

    # Trackers
    self._bpi = keras.metrics.Mean(name="gs_bits_per_index")  # KL lower bound (bits/index)
    self._perp = keras.metrics.Mean(name="gs_perplexity")
    self._usage = keras.metrics.Mean(name="gs_usage")
    self._tau = keras.metrics.Mean(name="gs_temperature")

    # weights created in build()
    self._proj = None  # [Din, K] if input_is_logits=False
    self._bias = None  # [K]     if use_bias
    self._embed = None  # [K, D]
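Because the temperature is stored as a non-trainable weight (self.tau), it can be annealed during training from a callback. A minimal sketch, assuming a standard Keras training loop; the callback name and decay schedule are illustrative, not part of the library:

```python
import keras

class GumbelTemperatureAnnealer(keras.callbacks.Callback):
    """Exponentially decay the bottleneck temperature tau down to a floor."""

    def __init__(self, bottleneck, decay=0.999, tau_min=0.5):
        super().__init__()
        self.bottleneck = bottleneck  # the GumbelSoftmaxBottleneck instance
        self.decay = decay
        self.tau_min = tau_min

    def on_train_batch_end(self, batch, logs=None):
        tau = float(self.bottleneck.tau.numpy())
        self.bottleneck.tau.assign(max(tau * self.decay, self.tau_min))
```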