GumbelSoftmaxBottleneck(num_embeddings: int, embedding_dim: int, temperature: float = 1.0, hard: bool = True, input_is_logits: bool = False, use_bias: bool = True, kl_weight: float = 1.0, **kwargs)
Discrete bottleneck via Gumbel-Softmax (Concrete) with optional straight-through hard one-hot.
Inputs
x: [..., Din] (features) -- if input_is_logits=False (default), we learn a linear projection to K logits
OR
x: [..., K] (logits) -- if input_is_logits=True, we treat last dim as K logits directly
Outputs
z: [..., D] expected embedding z = soft_one_hot @ embed (D = embedding_dim)
Adds loss
kl_weight * mean_bits_per_index (KL(q || Uniform(K)) in bits, averaged over tokens)
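For reference, the per-index KL to a uniform prior in bits can be computed as in the following sketch; kl_to_uniform_bits is an illustrative helper (not part of the layer's API), and float32 logits are assumed.

import math
from keras import ops

def kl_to_uniform_bits(logits):
    # Illustrative helper: KL(q || Uniform(K)) in bits, where q = softmax(logits).
    # KL(q || U) = log K + sum_k q_k * log q_k (nats); divide by ln 2 for bits.
    log_q = ops.log_softmax(logits, axis=-1)
    q = ops.exp(log_q)
    k = ops.cast(ops.shape(logits)[-1], "float32")
    kl_nats = ops.sum(q * log_q, axis=-1) + ops.log(k)
    return kl_nats / math.log(2.0)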
Tracks metrics (reported via the layer's metrics):
- gs_bits_per_index (lower bound, bits/index)
- gs_perplexity (empirical perplexity from hard argmax histogram)
- gs_usage (fraction of codes used at least once in the batch)
- gs_temperature (current τ; useful when annealing)
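A minimal usage sketch. The import path follows the source location below; the input shape and the standard Keras call convention (training=True) are assumptions, and the exact loss/metric bookkeeping depends on your Keras version.

import keras
from helia_edge.layers.gumbel_softmax_bottleneck import GumbelSoftmaxBottleneck

bottleneck = GumbelSoftmaxBottleneck(num_embeddings=64, embedding_dim=32, temperature=1.0)
x = keras.random.normal((8, 16, 128))        # [batch, tokens, Din] features (input_is_logits=False)
z = bottleneck(x, training=True)             # [8, 16, 32] expected embeddings
print(bottleneck.losses)                     # includes kl_weight * bits_per_index
print([m.name for m in bottleneck.metrics])  # gs_bits_per_index, gs_perplexity, ...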
Source code in helia_edge/layers/gumbel_softmax_bottleneck.py
def __init__(
    self,
    num_embeddings: int,
    embedding_dim: int,
    temperature: float = 1.0,
    hard: bool = True,
    input_is_logits: bool = False,
    use_bias: bool = True,
    kl_weight: float = 1.0,
    **kwargs,
):
    super().__init__(**kwargs)
    if num_embeddings <= 1 or embedding_dim <= 0:
        raise ValueError("num_embeddings must be >= 2 and embedding_dim > 0.")
    self.K = int(num_embeddings)
    self.D = int(embedding_dim)
    self.hard = bool(hard)
    self.input_is_logits = bool(input_is_logits)
    self.use_bias = bool(use_bias)
    self.kl_weight = float(kl_weight)
    # temperature stored as a non-trainable weight (easy to anneal via a callback)
    self.tau = self.add_weight(
        name="temperature",
        shape=(),
        initializer=keras.initializers.Constant(float(temperature)),
        trainable=False,
        dtype="float32",
    )
    # Metric trackers
    self._bpi = keras.metrics.Mean(name="gs_bits_per_index")  # KL lower bound (bits/index)
    self._perp = keras.metrics.Mean(name="gs_perplexity")
    self._usage = keras.metrics.Mean(name="gs_usage")
    self._tau = keras.metrics.Mean(name="gs_temperature")
    # weights created in build()
    self._proj = None   # [Din, K] if input_is_logits=False
    self._bias = None   # [K] if use_bias
    self._embed = None  # [K, D]
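Because the temperature lives in a non-trainable weight, it can be annealed from a callback during training. A sketch of one such schedule; TauAnneal and its schedule constants are illustrative, not part of the package.

import keras

class TauAnneal(keras.callbacks.Callback):
    # Illustrative callback: exponentially decay the bottleneck temperature each epoch.
    def __init__(self, layer, tau_start=1.0, tau_min=0.5, decay=0.97):
        super().__init__()
        self.layer = layer
        self.tau_start, self.tau_min, self.decay = tau_start, tau_min, decay

    def on_epoch_begin(self, epoch, logs=None):
        tau = max(self.tau_min, self.tau_start * (self.decay ** epoch))
        self.layer.tau.assign(tau)  # writes the non-trainable "temperature" weight

Pass an instance to fit, e.g. model.fit(..., callbacks=[TauAnneal(bottleneck)]), and watch gs_temperature to confirm the schedule is applied.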