
simclr

SimCLR Trainer API

This module implements a SimCLR trainer for training a model with SimCLR, a self-supervised contrastive learning approach.
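SimCLR trains the encoder and projector jointly with the NT-Xent (normalized temperature-scaled cross-entropy) loss described in the paper. Each sample's two augmented views form a positive pair, and the other 2N - 2 views in the batch act as negatives. This page does not show how `SimCLRTrainer` computes its loss, so the following is only an illustrative NumPy sketch of NT-Xent. It assumes the projections arrive as two `(N, D)` arrays and uses a temperature of 0.5 chosen for illustration. `nt_xent_loss` is a hypothetical helper and is not part of neuralspot_edge.

```python
import numpy as np


def nt_xent_loss(z1: np.ndarray, z2: np.ndarray, temperature: float = 0.5) -> float:
    """Illustrative NT-Xent loss for N positive pairs (z1[i], z2[i]).

    Hypothetical helper, not the neuralspot_edge implementation.
    """
    n = z1.shape[0]
    z = np.concatenate([z1, z2], axis=0)              # (2N, D): all views in one batch
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # L2-normalize each row
    sim = z @ z.T / temperature                       # cosine similarity / temperature
    np.fill_diagonal(sim, -np.inf)                    # a view is never its own candidate
    # The positive for row i is its other view: i <-> i + N.
    pos = np.concatenate([np.arange(n, 2 * n), np.arange(0, n)])
    # Numerically stable log-softmax over each row.
    row_max = sim.max(axis=1, keepdims=True)
    log_norm = np.log(np.exp(sim - row_max).sum(axis=1, keepdims=True))
    log_prob = sim - row_max - log_norm
    # Average cross-entropy of picking the positive among all other views.
    return float(-log_prob[np.arange(2 * n), pos].mean())
```

Lowering the temperature sharpens the softmax over similarities, which weights hard negatives more heavily.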

Classes

SimCLRTrainer

SimCLRTrainer(encoder: keras.Model, projector: keras.Model | None = None, **kwargs)

Creates a SimCLRTrainer.

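If no projector is provided, a default one is created following the SimCLR paper: two Dense layers whose width equals the encoder's output dimension (the first with ReLU activation), followed by BatchNormalization.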

References

  • SimCLR paper (https://arxiv.org/pdf/2002.05709)

Parameters:

  • encoder (Model) –

    The encoder model.

  • projector (Model | None, default: None) –

    The projector model. If None, a default projector is created.

Source code in neuralspot_edge/trainers/simclr.py
def __init__(
    self,
    encoder: keras.Model,
    projector: keras.Model | None = None,
    **kwargs,
):
    """Creates a SimCLRTrainer.

    If no projector is provided, a default one will be created based on the SimCLR paper.

    References:
        - [SimCLR paper](https://arxiv.org/pdf/2002.05709)

    Args:
        encoder (keras.Model): The encoder model.
        projector (keras.Model, optional): The projector model. Defaults to None.
    """
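    if projector is None:
        # Default projector: two Dense layers sized to the encoder's output
        # width (ReLU on the first), followed by BatchNormalization.
        projection_width = encoder.output_shape[-1]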
        projector = keras.Sequential(
            [
                keras.layers.Dense(projection_width, activation="relu"),
                keras.layers.Dense(projection_width),
                keras.layers.BatchNormalization(),
            ],
            name="projector",
        )

    super().__init__(
        encoder=encoder,
        projector=projector,
        **kwargs,
    )
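The default projector keeps every layer at the encoder's output width (`encoder.output_shape[-1]`), so the projection has the same dimensionality as the representation. As a framework-free illustration, the following NumPy sketch mirrors that forward pass. It assumes a 2-D encoder output of shape `(batch, D)` and BatchNormalization in training mode (batch statistics) with Keras's default `epsilon=1e-3`. The function, weight names, and sizes are hypothetical, and this is not the library's code.

```python
import numpy as np


def default_projector_forward(
    h: np.ndarray,
    w1: np.ndarray, b1: np.ndarray,
    w2: np.ndarray, b2: np.ndarray,
    gamma: np.ndarray, beta: np.ndarray,
    eps: float = 1e-3,
) -> np.ndarray:
    """Training-mode sketch of Dense(relu) -> Dense -> BatchNormalization.

    Hypothetical helper; all layers share the width D = h.shape[-1].
    """
    x = np.maximum(h @ w1 + b1, 0.0)  # first Dense layer with ReLU
    x = x @ w2 + b2                   # second Dense layer, linear
    mean = x.mean(axis=0)             # per-feature batch mean
    var = x.var(axis=0)               # per-feature (biased) batch variance
    return gamma * (x - mean) / np.sqrt(var + eps) + beta


rng = np.random.default_rng(0)
batch, width = 32, 64  # hypothetical batch size and encoder output width
h = rng.normal(size=(batch, width))  # stand-in for encoder outputs
z = default_projector_forward(
    h,
    w1=rng.normal(scale=width ** -0.5, size=(width, width)), b1=np.zeros(width),
    w2=rng.normal(scale=width ** -0.5, size=(width, width)), b2=np.zeros(width),
    gamma=np.ones(width), beta=np.zeros(width),
)
print(z.shape)  # (32, 64)
```

Because both Dense layers use `projection_width`, an encoder with a different output width changes the default projector's size automatically; pass an explicit `projector` to use a different projection dimension.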

Functions