mask_autoencoder

Masked Autoencoder Trainer API

This module implements a masked autoencoder trainer for self-supervised learning. The model is trained to reconstruct the masked patches of its input.

Classes

MaskedAutoencoder

MaskedAutoencoder(patch_layer: Callable[[keras.KerasTensor], tuple[keras.KerasTensor, keras.KerasTensor, keras.KerasTensor, keras.KerasTensor, keras.KerasTensor]], patch_encoder: Callable[[keras.KerasTensor], keras.KerasTensor], encoder: keras.Model, decoder: keras.Model, **kwargs)

Masked Autoencoder model for self-supervised learning.

Parameters:

  • patch_layer (Callable[[KerasTensor], tuple[KerasTensor, KerasTensor, KerasTensor, KerasTensor, KerasTensor]]) –

    The patch layer, which extracts patches from the input.

  • patch_encoder (Callable[[KerasTensor], KerasTensor]) –

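    The patch encoder, which encodes the patches. calculate_loss unpacks its output into unmasked embeddings, masked embeddings, unmasked positions, mask indices, and unmask indices.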

  • encoder (Model) –

    The encoder model.

  • decoder (Model) –

    The decoder model.

Source code in neuralspot_edge/trainers/mask_autoencoder.py
def __init__(
    self,
    patch_layer: Callable[
        [keras.KerasTensor],
        tuple[
            keras.KerasTensor,
            keras.KerasTensor,
            keras.KerasTensor,
            keras.KerasTensor,
            keras.KerasTensor,
        ],
    ],
    patch_encoder: Callable[[keras.KerasTensor], keras.KerasTensor],
    encoder: keras.Model,
    decoder: keras.Model,
    **kwargs,
):
    """Masked Autoencoder model for self-supervised learning.

    Args:
        patch_layer (Callable[[keras.KerasTensor], tuple[keras.KerasTensor, keras.KerasTensor, keras.KerasTensor, keras.KerasTensor, keras.KerasTensor]]): The patch layer which will extract patches from the input.
        patch_encoder (Callable[[keras.KerasTensor], keras.KerasTensor]): The patch encoder which will encode the patches.
        encoder (keras.Model): The encoder model.
        decoder (keras.Model): The decoder model.
    """
    super().__init__(**kwargs)
    self.patch_layer = patch_layer
    self.patch_encoder = patch_encoder
    self.encoder = encoder
    self.decoder = decoder
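The contract for the two callables is easiest to see with stand-ins. The sketch below uses NumPy only and follows the way calculate_loss calls them: patch_layer returns a single patch tensor and patch_encoder returns the five values that calculate_loss unpacks, which is the reverse of what the type annotations suggest. The names toy_patch_layer and toy_patch_encoder, the shapes, the 50% mask ratio, and the random projection are all assumptions made for illustration, not part of neuralspot_edge.

```python
import numpy as np

rng = np.random.default_rng(0)


def toy_patch_layer(x, patch_size=4):
    """Hypothetical patch layer: (batch, length, channels) -> (batch, num_patches, patch_size * channels)."""
    b, length, ch = x.shape
    return x.reshape(b, length // patch_size, patch_size * ch)


def toy_patch_encoder(patches, embed_dim=8, mask_ratio=0.5):
    """Hypothetical patch encoder returning the five values calculate_loss unpacks."""
    b, n, d = patches.shape
    num_mask = int(n * mask_ratio)

    # Random stand-ins for what would be learned weights in a real model.
    proj = rng.normal(size=(d, embed_dim))       # patch projection
    positions = rng.normal(size=(n, embed_dim))  # position embeddings
    mask_token = rng.normal(size=(embed_dim,))   # shared mask token

    # Independent random ordering of patch positions for every sample.
    order = np.argsort(rng.random((b, n)), axis=1)
    mask_indices, unmask_indices = order[:, :num_mask], order[:, num_mask:]

    rows = np.arange(b)[:, None]
    unmasked_positions = positions[unmask_indices]
    unmasked_embeddings = patches[rows, unmask_indices] @ proj + unmasked_positions
    masked_embeddings = mask_token + positions[mask_indices]
    return (
        unmasked_embeddings,
        masked_embeddings,
        unmasked_positions,
        mask_indices,
        unmask_indices,
    )


x = rng.normal(size=(2, 16, 1))       # 2 samples, 16 time steps, 1 channel
patches = toy_patch_layer(x)          # shape (2, 4, 4)
outputs = toy_patch_encoder(patches)  # five arrays, in the order calculate_loss expects
```

In a real setup these would presumably be Keras layers with trainable weights. The point here is only the shapes and the order of the returned values.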

Functions

calculate_loss
calculate_loss(x: keras.KerasTensor, test: bool = False)

Calculate the loss for the Masked Autoencoder model.

Parameters:

  • x (KerasTensor) –

    The input tensor.

  • test (bool, default: False ) –

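    Whether the loss is computed during evaluation rather than training. Defaults to False.

Returns:

  • tuple –

    (total_loss, loss_patch, loss_output): the loss returned by compute_loss, the input patches at the masked positions, and the decoder's reconstruction of those patches.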

Source code in neuralspot_edge/trainers/mask_autoencoder.py
def calculate_loss(self, x: keras.KerasTensor, test: bool = False):
    """Calculate the loss for the Masked Autoencoder model.

    Args:
        x (keras.KerasTensor): The input tensor.
        test (bool, optional): Whether the model is testing. Defaults to False.
    """
    # Patch the input.
    patches = self.patch_layer(x)

    # Encode the patches.
    (
        unmasked_embeddings,
        masked_embeddings,
        unmasked_positions,
        mask_indices,
        unmask_indices,
    ) = self.patch_encoder(patches)

    # Pass the unmasked patches to the encoder.
    encoder_outputs = self.encoder(unmasked_embeddings)

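    # Build the decoder inputs: add the unmasked positions to the encoder
    # outputs, then append the masked embeddings along the patch axis.
    encoder_outputs = encoder_outputs + unmasked_positions
    decoder_inputs = keras.ops.concatenate([encoder_outputs, masked_embeddings], axis=1)

    # Decode the inputs, then re-patch the reconstruction so it lines up
    # with the input patches.
    decoder_outputs = self.decoder(decoder_inputs)
    decoder_patches = self.patch_layer(decoder_outputs)

    # Score only the masked patches: gather them per sample from both the
    # input patches and the reconstruction.
    loss_patch = tf.gather(patches, mask_indices, axis=1, batch_dims=1)
    loss_output = tf.gather(decoder_patches, mask_indices, axis=1, batch_dims=1)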

    # Compute the total loss.
    total_loss = self.compute_loss(y=loss_patch, y_pred=loss_output)

    return total_loss, loss_patch, loss_output
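The two tf.gather calls with axis=1 and batch_dims=1 pick, for each sample, the patches at that sample's own mask_indices. A minimal NumPy sketch of that selection and of a loss restricted to the masked patches is shown below. The helper gather_patches, the toy data, and the use of mean squared error are assumptions for illustration. The actual loss is whatever compute_loss applies, presumably the one the model was compiled with.

```python
import numpy as np


def gather_patches(patches, indices):
    """NumPy stand-in for tf.gather(patches, indices, axis=1, batch_dims=1)."""
    batch = np.arange(patches.shape[0])[:, None]  # (B, 1), broadcasts against (B, M)
    return patches[batch, indices]                # (B, M, D)


# Toy data: 2 samples, 4 patches per sample, 3 values per patch.
patches = np.arange(24, dtype=np.float32).reshape(2, 4, 3)
reconstruction = patches + 1.0             # pretend decoder output, off by 1 everywhere
mask_indices = np.array([[0, 2], [3, 1]])  # masked patch positions, per sample

loss_patch = gather_patches(patches, mask_indices)
loss_output = gather_patches(reconstruction, mask_indices)

# Mean squared error over the masked patches only (assumed loss).
masked_mse = float(np.mean((loss_patch - loss_output) ** 2))
```

Because only the masked positions are gathered, reconstruction errors on the visible patches do not contribute to the loss.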