
contrastive

Contrastive Trainer API

This module provides a trainer for contrastive learning.

Classes

ContrastiveTrainer

ContrastiveTrainer(encoder: keras.Model, projector: keras.Model | tuple[keras.Model, keras.Model], augmenter: keras.Layer | tuple[keras.Layer, keras.Layer] | None = None, probe: keras.Layer | keras.Model | None = None)

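Creates a self-supervised contrastive trainer for an encoder model. The trainer works on two views of each input, so the augmenter and the projector can each be given once (shared by both views) or as a pair (one per view).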

Parameters:

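  • encoder (Model) –

    The encoder model to be trained. Its output must be flattened, i.e. rank 2 with shape (batch, features); otherwise a ValueError is raised.

  • projector (Model | tuple[Model, Model]) –

    The projector model to be trained. Pass a single model to share one projector between both views, or a tuple of exactly two models to give each view its own projector. A tuple of any other length raises a ValueError.

  • augmenter (Layer | tuple[Layer, Layer] | None, default: None) –

    The augmenter used for data augmentation. Pass a single layer to apply the same augmenter to both views, or a tuple of exactly two layers to use one per view. If None, both views pass through unchanged (identity). A tuple of any other length raises a ValueError.

  • probe (Layer | Model | None, default: None) –

    The probe model to be trained. If None, no probe is used.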
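The rules for `augmenter` and `projector` can be illustrated with a small pure-Python sketch. This is not the library's code: plain callables stand in for Keras layers and models, and the helper names `resolve_augmenters` and `resolve_projectors` are hypothetical. It only mirrors how the constructor turns each argument into one object per view.

```python
from typing import Callable, Optional, Tuple, Union

# Plain callables stand in for keras.Layer / keras.Model objects in this sketch.
Fn = Callable[[object], object]


def identity(x):
    """Pass-through used when no augmenter is given.

    The real trainer builds two keras.layers.Lambda(lambda x: x) layers instead.
    """
    return x


def resolve_augmenters(augmenter: Optional[Union[Fn, Tuple[Fn, Fn]]]) -> Tuple[Fn, Fn]:
    """Return one augmenter per view (hypothetical helper)."""
    if augmenter is None:
        return (identity, identity)  # no augmentation on either view
    if isinstance(augmenter, tuple):
        if len(augmenter) != 2:
            raise ValueError("augmenter tuple must contain exactly 2 augmenters")
        return augmenter  # a distinct augmenter for each view
    return (augmenter, augmenter)  # the same augmenter for both views


def resolve_projectors(projector: Union[Fn, Tuple[Fn, Fn]]) -> Tuple[Tuple[Fn, Fn], bool]:
    """Return (projector pair, is_shared) (hypothetical helper)."""
    if isinstance(projector, tuple):
        if len(projector) != 2:
            raise ValueError("projector tuple must contain exactly 2 projectors")
        return projector, False  # distinct projectors, not shared
    return (projector, projector), True  # one projector shared by both views


# Usage: a single augmenter is reused for both views.
def reverse(x):
    return x[::-1]


aug_a, aug_b = resolve_augmenters(reverse)
print(aug_a is aug_b, aug_a([1, 2, 3]))
```

Sharing one projector ties both views to the same weights, while a pair lets each view learn its own projection head.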

Source code in neuralspot_edge/trainers/contrastive.py
def __init__(
    self,
    encoder: keras.Model,
    projector: keras.Model | tuple[keras.Model, keras.Model],
    augmenter: keras.Layer | tuple[keras.Layer, keras.Layer] | None = None,
    probe: keras.Layer | keras.Model | None = None,
):
    """Creates a self-supervised contrastive trainer for a model.

    Args:
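        encoder (keras.Model): The encoder model to be trained. Its output must be rank 2 (batch, features).
        projector (keras.Model|tuple[keras.Model, keras.Model]): The projector model to be trained. A single
            model is shared by both views; a tuple of two gives each view its own projector.
        augmenter (keras.Layer|tuple[keras.Layer, keras.Layer]|None): The augmenter used for data augmentation.
            A single layer is applied to both views; a tuple of two gives one per view. If None, identity is used.
        probe (keras.Layer|keras.Model|None): The probe model to be trained. If None, no probe is used.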

    """
    super().__init__()

    if len(encoder.output.shape) != 2:
        raise ValueError(
            "`encoder` must have a flattened output. Expected "
            "rank(encoder.output.shape)=2, got "
            f"encoder.output.shape={encoder.output.shape}"
        )

    if isinstance(augmenter, tuple) and len(augmenter) != 2:
        raise ValueError("`augmenter` must be either a single augmenter or a tuple of exactly 2 augmenters.")

    if isinstance(projector, tuple) and len(projector) != 2:
        raise ValueError("`projector` must be either a single projector or a tuple of exactly 2 projectors.")

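    if augmenter is None:
        # No augmentation: both views pass through identity layers unchanged.
        self.augmenters = (keras.layers.Lambda(lambda x: x), keras.layers.Lambda(lambda x: x))
    elif isinstance(augmenter, tuple):
        # A distinct augmenter for each view.
        self.augmenters = augmenter
    else:
        # The same augmenter is applied to both views.
        self.augmenters = (augmenter, augmenter)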

    self.encoder = encoder

    # A single projector is shared by both views; a tuple gives each view its own projector.
    # Use isinstance() consistently so tuple subclasses are handled the same way as in the check above.
    self._is_shared_projector = not isinstance(projector, tuple)
    self.projectors = projector if isinstance(projector, tuple) else (projector, projector)
    self.probe = probe

    self.loss_metric = keras.metrics.Mean(name="loss")
    self.encoder_metrics = []
    if probe is not None:
        self.probe_loss_metric = keras.metrics.Mean(name="probe_loss")
        self.probe_metrics = []
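The constructor rejects any encoder whose output is not rank 2. The sketch below mirrors that check on plain shape tuples (with None as the batch dimension) and shows the shape a Flatten-style layer would produce. The names `check_flattened` and `flattened_shape` are hypothetical; in Keras, appending `keras.layers.Flatten()` or a global pooling layer such as `keras.layers.GlobalAveragePooling1D()` to the encoder is a common way to reach rank 2, though pooling yields (batch, channels) rather than the flattened size modeled here.

```python
from math import prod
from typing import Optional, Sequence, Tuple

# A shape such as (None, 32, 64); None marks the unknown batch dimension.
Shape = Sequence[Optional[int]]


def check_flattened(output_shape: Shape) -> None:
    """Mirror of the encoder check: the output must be (batch, features)."""
    if len(output_shape) != 2:
        raise ValueError(
            "`encoder` must have a flattened output. Expected rank 2, "
            f"got shape={tuple(output_shape)}"
        )


def flattened_shape(output_shape: Shape) -> Tuple[Optional[int], int]:
    """Shape after a Flatten-style layer: keep the batch dim, multiply the rest."""
    batch, *rest = output_shape
    return (batch, prod(rest))


# Example: a 1D-conv encoder ending in (batch, 32 steps, 64 channels) fails the check...
conv_out = (None, 32, 64)
try:
    check_flattened(conv_out)
except ValueError as err:
    print(err)

# ...while its flattened output passes.
check_flattened(flattened_shape(conv_out))
```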

Functions

save
save(filepath, overwrite=True, zipped=True, **kwargs)

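Saves only the encoder model to filepath. The projector(s), augmenter(s), and probe are not saved; overwrite, zipped, and any extra keyword arguments are forwarded to the encoder's save method.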

Source code in neuralspot_edge/trainers/contrastive.py
def save(self, filepath, overwrite=True, zipped=True, **kwargs):
    """Save only the encoder model; projectors, augmenters, and probe are not saved."""
    self.encoder.save(filepath, overwrite=overwrite, zipped=zipped, **kwargs)
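Because `save` delegates to the encoder, the file on disk holds the encoder alone. The sketch below makes that visible with hypothetical stand-in classes (`RecordingModel`, `TrainerSketch`) that record which object was asked to save; they are not part of neuralspot_edge or Keras.

```python
class RecordingModel:
    """Hypothetical stand-in for keras.Model that records save() calls."""

    def __init__(self, name: str):
        self.name = name
        self.saved_to = []

    def save(self, filepath, overwrite=True, zipped=True, **kwargs):
        # Record the path instead of writing a file.
        self.saved_to.append(filepath)


class TrainerSketch:
    """Hypothetical trainer that, like ContrastiveTrainer.save, persists only the encoder."""

    def __init__(self, encoder, projector, probe=None):
        self.encoder = encoder
        self.projectors = (projector, projector)
        self.probe = probe

    def save(self, filepath, overwrite=True, zipped=True, **kwargs):
        # Delegate to the encoder; projectors and probe are intentionally skipped.
        self.encoder.save(filepath, overwrite=overwrite, zipped=zipped, **kwargs)


enc, proj, probe = RecordingModel("encoder"), RecordingModel("projector"), RecordingModel("probe")
TrainerSketch(enc, proj, probe).save("encoder.keras")
print(enc.saved_to, proj.saved_to, probe.saved_to)
```

To keep the projector or probe as well, save them separately, for example by calling `save` on `trainer.projectors[0]`, or on `trainer.probe` when the probe is a `keras.Model`.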
