Masked Autoencoder Trainer API
This module implements a masked autoencoder (MAE) trainer for self-supervised learning. Some of each input's patches are masked, and the model is trained to reconstruct those masked patches from the visible ones.
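In the masked autoencoder approach, each input's patch positions are split into a masked set, hidden from the encoder, and an unmasked set, which the encoder sees. `calculate_loss` receives these as `mask_indices` and `unmask_indices` from the patch encoder. The NumPy sketch below shows one way such a split could be produced, assuming random per-example masking as in the standard MAE recipe. The helper `random_mask_indices` and its `mask_ratio` parameter are hypothetical and are not part of this module.

```python
import numpy as np


def random_mask_indices(batch_size: int, num_patches: int, mask_ratio: float, seed: int = 0):
    """Hypothetical helper: split patch positions into masked and unmasked sets.

    Returns (mask_indices, unmask_indices), integer arrays of shape
    (batch_size, num_mask) and (batch_size, num_patches - num_mask).
    """
    rng = np.random.default_rng(seed)
    num_mask = int(num_patches * mask_ratio)
    # argsort of uniform noise gives an independent random permutation per example.
    perms = np.argsort(rng.random((batch_size, num_patches)), axis=-1)
    mask_indices = perms[:, :num_mask]    # positions hidden from the encoder
    unmask_indices = perms[:, num_mask:]  # positions the encoder sees
    return mask_indices, unmask_indices


mask_idx, unmask_idx = random_mask_indices(batch_size=2, num_patches=8, mask_ratio=0.75)
print(mask_idx.shape, unmask_idx.shape)  # (2, 6) (2, 2)
```

Every patch position lands in exactly one of the two sets, so each example's masked and unmasked indices together cover all patches.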
Classes
MaskedAutoencoder
MaskedAutoencoder(patch_layer: Callable[[keras.KerasTensor], keras.KerasTensor], patch_encoder: Callable[[keras.KerasTensor], tuple[keras.KerasTensor, keras.KerasTensor, keras.KerasTensor, keras.KerasTensor, keras.KerasTensor]], encoder: keras.Model, decoder: keras.Model, **kwargs)
Masked Autoencoder model for self-supervised learning.
Parameters:
- patch_layer (Callable[[KerasTensor], KerasTensor]) – The patch layer, which extracts patches from the input. It is also applied to the decoder output so that the reconstruction can be compared with the input patch by patch.
- patch_encoder (Callable[[KerasTensor], tuple[KerasTensor, KerasTensor, KerasTensor, KerasTensor, KerasTensor]]) – The patch encoder, which encodes the patches and returns (unmasked_embeddings, masked_embeddings, unmasked_positions, mask_indices, unmask_indices).
- encoder (Model) – The encoder model, applied to the unmasked embeddings only.
- decoder (Model) – The decoder model, which reconstructs the input from the encoded unmasked patches and the masked embeddings.
Source code in neuralspot_edge/trainers/mask_autoencoder.py
    def __init__(
        self,
        patch_layer: Callable[[keras.KerasTensor], keras.KerasTensor],
        patch_encoder: Callable[
            [keras.KerasTensor],
            tuple[
                keras.KerasTensor,
                keras.KerasTensor,
                keras.KerasTensor,
                keras.KerasTensor,
                keras.KerasTensor,
            ],
        ],
        encoder: keras.Model,
        decoder: keras.Model,
        **kwargs,
    ):
        """Masked Autoencoder model for self-supervised learning.

        Args:
            patch_layer (Callable[[keras.KerasTensor], keras.KerasTensor]): The patch layer, which extracts patches from the input.
            patch_encoder (Callable[[keras.KerasTensor], tuple[keras.KerasTensor, keras.KerasTensor, keras.KerasTensor, keras.KerasTensor, keras.KerasTensor]]): The patch encoder, which encodes the patches.
            encoder (keras.Model): The encoder model.
            decoder (keras.Model): The decoder model.
        """
        super().__init__(**kwargs)
        self.patch_layer = patch_layer
        self.patch_encoder = patch_encoder
        self.encoder = encoder
        self.decoder = decoder
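In `calculate_loss`, `patch_layer` is called twice: once on the raw input and once on the decoder output. It therefore has to map one tensor to one tensor of patches, and the decoder's output must be something the same layer can split into patches comparable with the input's. The NumPy sketch below illustrates that contract for 1-D signals, assuming an input of shape (batch, time, channels) and non-overlapping windows. The function `extract_patches_1d` is hypothetical; the module's actual patch layer is not shown here and may work differently (a Keras layer would use `keras.ops` rather than NumPy).

```python
import numpy as np


def extract_patches_1d(x: np.ndarray, patch_size: int) -> np.ndarray:
    """Hypothetical patch layer for 1-D signals.

    Maps (batch, time, channels) -> (batch, num_patches, patch_size * channels)
    by cutting the time axis into non-overlapping windows.
    """
    batch, time, channels = x.shape
    if time % patch_size != 0:
        raise ValueError("time length must be a multiple of patch_size")
    num_patches = time // patch_size
    # Consecutive time steps are contiguous in memory, so a reshape groups
    # them into patches and flattens each patch in one step.
    return x.reshape(batch, num_patches, patch_size * channels)


x = np.arange(2 * 8 * 1, dtype=np.float32).reshape(2, 8, 1)
patches = extract_patches_1d(x, patch_size=4)
print(patches.shape)  # (2, 2, 4)
print(patches[0, 1])  # [4. 5. 6. 7.]
```

Because the same function patches both the input and the reconstruction, patch `i` of the input and patch `i` of the decoder output cover the same time window, which is what makes the per-index comparison in the loss meaningful.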
Functions
calculate_loss
calculate_loss(x: keras.KerasTensor, test: bool = False)
Calculate the loss for the Masked Autoencoder model.
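Parameters:
- x (KerasTensor) – The input tensor.
- test (bool, default: False) – Whether the model is in testing mode.
Returns:
- tuple[KerasTensor, KerasTensor, KerasTensor] – (total_loss, loss_patch, loss_output): the computed loss, the input patches at the masked positions, and the reconstructed patches at the same positions.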
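`calculate_loss` returns `(total_loss, loss_patch, loss_output)`, where the input and the reconstruction are both compared only at the masked positions. The NumPy sketch below mirrors that selection: `gather_patches` reproduces what `tf.gather(patches, indices, axis=1, batch_dims=1)` does for a (batch, num_patches, patch_dim) tensor and (batch, k) indices. The mean squared error is an assumption for illustration only, because `self.compute_loss` uses whatever loss the model was compiled with. Both helper names are hypothetical.

```python
import numpy as np


def gather_patches(patches: np.ndarray, indices: np.ndarray) -> np.ndarray:
    """NumPy equivalent of tf.gather(patches, indices, axis=1, batch_dims=1).

    patches: (batch, num_patches, patch_dim); indices: (batch, k).
    Returns (batch, k, patch_dim) with out[b, j] = patches[b, indices[b, j]].
    """
    batch_idx = np.arange(patches.shape[0])[:, None]  # (batch, 1), broadcasts over k
    return patches[batch_idx, indices]


def masked_mse(patches, reconstructed_patches, mask_indices):
    """Assumed loss: mean squared error over the masked patches only."""
    loss_patch = gather_patches(patches, mask_indices)
    loss_output = gather_patches(reconstructed_patches, mask_indices)
    total_loss = np.mean((loss_patch - loss_output) ** 2)
    return total_loss, loss_patch, loss_output


patches = np.zeros((1, 4, 2))
recon = patches.copy()
recon[0, 3] = 1.0  # only patch 3 is reconstructed incorrectly
loss, target, pred = masked_mse(patches, recon, mask_indices=np.array([[1, 3]]))
print(loss, target.shape)  # 0.5 (1, 2, 2)
```

Errors on unmasked patches do not contribute: with `mask_indices=[[0, 1]]` the same reconstruction gives a loss of 0.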
Source code in neuralspot_edge/trainers/mask_autoencoder.py
    def calculate_loss(self, x: keras.KerasTensor, test: bool = False):
        """Calculate the loss for the Masked Autoencoder model.

        Args:
            x (keras.KerasTensor): The input tensor.
            test (bool, optional): Whether the model is testing. Defaults to False.
        """
        # Patch the input.
        patches = self.patch_layer(x)
        # Encode the patches and split them into unmasked and masked sets.
        (
            unmasked_embeddings,
            masked_embeddings,
            unmasked_positions,
            mask_indices,
            unmask_indices,
        ) = self.patch_encoder(patches)
        # Pass only the unmasked patches to the encoder.
        encoder_outputs = self.encoder(unmasked_embeddings)
        # Create the decoder inputs: add the unmasked positions to the encoder
        # outputs, then append the masked embeddings along the patch axis.
        encoder_outputs = encoder_outputs + unmasked_positions
        decoder_inputs = keras.ops.concatenate([encoder_outputs, masked_embeddings], axis=1)
        # Decode the inputs and re-patch the reconstruction.
        decoder_outputs = self.decoder(decoder_inputs)
        decoder_patches = self.patch_layer(decoder_outputs)
        # Keep only the masked patches of each example for the loss.
        loss_patch = tf.gather(patches, mask_indices, axis=1, batch_dims=1)
        loss_output = tf.gather(decoder_patches, mask_indices, axis=1, batch_dims=1)
        # Compute the total loss.
        total_loss = self.compute_loss(y=loss_patch, y_pred=loss_output)
        return total_loss, loss_patch, loss_output
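The decoder-input step above concatenates the encoded visible patches with the masked embeddings along axis 1, so the decoder again sees one token per patch. This only works if the encoder preserves the embedding width of `masked_embeddings`. The shape-only NumPy sketch below uses constant stand-in arrays, and assumes `unmasked_positions` holds the positional embeddings of the visible patches; the sizes are arbitrary illustrations.

```python
import numpy as np

batch, num_patches, dim = 2, 8, 16
num_mask = 6
num_unmask = num_patches - num_mask

# Stand-ins for the encoder and patch-encoder outputs (values are placeholders).
encoder_outputs = np.ones((batch, num_unmask, dim))
unmasked_positions = np.zeros((batch, num_unmask, dim))
masked_embeddings = np.full((batch, num_mask, dim), -1.0)

# Add position information to the encoded visible patches, then append the
# masked embeddings along the patch axis (axis=1), as calculate_loss does.
decoder_inputs = np.concatenate(
    [encoder_outputs + unmasked_positions, masked_embeddings], axis=1
)
print(decoder_inputs.shape)  # (2, 8, 16)
```

The visible patches occupy the first `num_unmask` slots and the masked embeddings the rest, so the decoder input has `num_unmask + num_mask = num_patches` tokens of width `dim`.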