VQAutoencoder(encoder: keras.Model, vq: VectorQuantizer, decoder: keras.Model, **kwargs)
Convenience wrapper around (encoder -> VectorQuantizer -> decoder).
- Supports extra reconstruction-side losses and metrics.
- Can return discrete code indices from the VQ bottleneck.
- Exposes VQ layer metrics alongside base model metrics.
Initialize the vector-quantized autoencoder.
Parameters:
-
encoder
(Model)
–
Encoder model producing continuous latents.
-
vq
(VectorQuantizer)
–
VectorQuantizer layer that discretizes latents.
-
decoder
(Model)
–
Decoder model mapping bottleneck outputs to reconstructions.
Source code in helia_edge/trainers/vq_autoencoder.py
def __init__(self, encoder: keras.Model, vq: VectorQuantizer, decoder: keras.Model, **kwargs):
    """Build the autoencoder from its three stages.

    Args:
        encoder: Model stored as ``self.encoder``; run first in ``call``.
        vq: VectorQuantizer layer stored as ``self.vq``; applied to the
            encoder output.
        decoder: Model stored as ``self.decoder``; applied to the VQ output.
        **kwargs: Forwarded unchanged to the parent constructor.
    """
    super().__init__(**kwargs)

    # The three stages, assigned in pipeline order.
    self.encoder, self.vq, self.decoder = encoder, vq, decoder

    # Loss/metric configuration; populated later by ``compile``.
    self._recon_loss = None
    self._extra_loss_fns = list()
    self._extra_metric_objs = list()  # every extra Metric object
    self._extra_metric_fns = list()  # (tracker, callable) pairs
|
Functions
call
call(x, training=False, return_indices: bool = False)
Run encoder -> VQ bottleneck -> decoder.
Parameters:
-
x
–
Input batch.
-
training
–
Whether to run in training mode (affects encoder/decoder/VQ).
-
return_indices
(bool, default:
False
)
–
If True, also return discrete code indices.
Returns:
-
–
Reconstruction, optionally with indices.
Source code in helia_edge/trainers/vq_autoencoder.py
def call(self, x, training=False, return_indices: bool = False):
    """Encode ``x``, quantize the latents, and decode the result.

    Args:
        x: Input batch passed to the encoder.
        training: Forwarded as ``training=`` to the encoder and the decoder.
        return_indices: Forwarded to the VQ layer; when truthy the VQ layer's
            output is unpacked as ``(quantized, indices)``.

    Returns:
        The decoder output, or ``(decoder_output, indices)`` when
        ``return_indices`` is truthy.
    """
    latents = self.encoder(x, training=training)
    bottleneck = self.vq(latents, return_indices=return_indices)

    # Common path: the VQ layer returned only the quantized latents.
    if not return_indices:
        return self.decoder(bottleneck, training=training)

    # Index path: the VQ layer returned a (quantized, indices) pair.
    quantized, code_ids = bottleneck
    reconstruction = self.decoder(quantized, training=training)
    return reconstruction, code_ids
|
compile
compile(optimizer: keras.optimizers.Optimizer, loss: keras.losses.Loss | None = None, metrics: list | None = None, extra_losses: list | None = None, extra_metrics: list | None = None, **kwargs)
Compile with optional extra losses/metrics.
Parameters:
-
optimizer
(Optimizer)
–
Keras optimizer
-
loss
(Loss | None, default:
None
)
–
base reconstruction loss (e.g., keras.losses.MeanSquaredError())
-
metrics
(list | None, default:
None
)
–
standard Keras metrics (Metric instances or callables)
-
extra_losses
(list | None, default:
None
)
–
list[Callable(y_true, y_pred) -> scalar]
-
extra_metrics
(list | None, default:
None
)
–
list of Metric OR Callable(y_true, y_pred) -> scalar
Source code in helia_edge/trainers/vq_autoencoder.py
def compile(
    self,
    optimizer: keras.optimizers.Optimizer,
    loss: keras.losses.Loss | None = None,
    metrics: list | None = None,
    extra_losses: list | None = None,
    extra_metrics: list | None = None,
    **kwargs,
):
    """
    Configure the model, recording the reconstruction loss and any extras.

    Args:
        optimizer: Passed to the parent ``compile``.
        loss: Stored as the base reconstruction loss; it is not handed to the
            parent ``compile``.
        metrics: Passed to the parent ``compile`` (an empty list when falsy).
        extra_losses: Callables invoked as ``fn(y_true, y_pred)``; copied
            into a fresh list.
        extra_metrics: ``keras.metrics.Metric`` instances are kept as-is; any
            other entry is treated as a callable and paired with a ``Mean``
            tracker named after its ``__name__`` (or ``"extra_metric"``).
        **kwargs: Passed through to the parent ``compile``.
    """
    super().compile(optimizer=optimizer, metrics=metrics or [], **kwargs)

    self._recon_loss = loss
    self._extra_loss_fns = list(extra_losses or [])

    # Reset in place so a repeated compile() does not accumulate entries.
    self._extra_metric_objs.clear()
    self._extra_metric_fns.clear()

    for entry in extra_metrics or []:
        if not isinstance(entry, keras.metrics.Metric):
            # Plain callable: wrap it in a Mean tracker and remember the pair.
            label = getattr(entry, "__name__", "extra_metric")
            mean_tracker = keras.metrics.Mean(name=label)
            self._extra_metric_fns.append((mean_tracker, entry))
            entry = mean_tracker
        self._extra_metric_objs.append(entry)
|
compute_loss
compute_loss(x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False)
Compute total loss = recon + extra losses + layer-added losses.
Source code in helia_edge/trainers/vq_autoencoder.py
def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False):
    """Sum the reconstruction loss, extra losses, and layer-added losses.

    Args:
        x: Accepted for signature compatibility; not read by this method.
        y: Targets for the reconstruction and extra losses.
        y_pred: Predictions for the reconstruction and extra losses.
        sample_weight: When not None, forwarded to the reconstruction loss
            as ``sample_weight=``.
        allow_empty: Accepted for signature compatibility; not read by this
            method.

    Returns:
        A tensor that starts at 0.0 in ``self.compute_dtype`` and has each
        applicable loss term added to it.
    """
    loss_sum = keras.ops.convert_to_tensor(0.0, dtype=self.compute_dtype)

    # Reconstruction term: only when a loss was compiled and both tensors exist.
    recon_fn = self._recon_loss
    if not (recon_fn is None or y is None or y_pred is None):
        # Pass sample_weight only when provided, so the loss callable is
        # invoked exactly as it would be without weights otherwise.
        weight_kwargs = {} if sample_weight is None else {"sample_weight": sample_weight}
        loss_sum = loss_sum + recon_fn(y, y_pred, **weight_kwargs)

    # User-supplied extra terms, each called as fn(y, y_pred).
    for extra_fn in self._extra_loss_fns:
        loss_sum = loss_sum + extra_fn(y, y_pred)

    # Losses registered on the model/layers (exposed through self.losses).
    layer_losses = self.losses
    if layer_losses:
        loss_sum = loss_sum + keras.ops.add_n(layer_losses)

    return loss_sum
|
compute_metrics
compute_metrics(x, y, y_pred, sample_weight=None)
Update compiled metrics plus extra metric trackers.
Source code in helia_edge/trainers/vq_autoencoder.py
def compute_metrics(self, x, y, y_pred, sample_weight=None):
    """Update compiled metrics plus every extra metric registered in ``compile``.

    ``compile`` stores two kinds of extra metrics: ``Mean`` trackers paired
    with a user callable (in ``_extra_metric_fns``) and ``Metric`` instances
    passed in directly (only in ``_extra_metric_objs``). Both kinds are
    updated here; previously the directly passed ``Metric`` instances were
    never updated by this method.

    Args:
        x: Input batch, forwarded to the parent implementation.
        y: Targets.
        y_pred: Predictions.
        sample_weight: Optional weights; forwarded to the parent and to
            directly passed ``Metric`` instances.

    Returns:
        The parent's result mapping, with one entry added (or overwritten)
        per extra metric, keyed by the metric's ``name``.
    """
    # Update any compiled metrics (e.g., keras.metrics.MeanSquaredError()).
    results = super().compute_metrics(x, y, y_pred, sample_weight)

    # Callable-backed trackers: feed the callable's scalar into the tracker.
    wrapped_ids = set()
    for tracker, fn in self._extra_metric_fns:
        tracker.update_state(fn(y, y_pred))
        results[tracker.name] = tracker.result()
        wrapped_ids.add(id(tracker))

    # Metric instances supplied directly via ``extra_metrics``: these follow
    # the (y_true, y_pred, sample_weight) update convention. Trackers that
    # were already handled above are skipped so they are not updated twice.
    for metric in self._extra_metric_objs:
        if id(metric) in wrapped_ids:
            continue
        metric.update_state(y, y_pred, sample_weight=sample_weight)
        results[metric.name] = metric.result()

    return results
|