Skip to content

random_choice

Random Choice Layer API

This module provides classes to build random choice layers.

Classes:

Classes

RandomChoice

RandomChoice(layers: list[BaseAugmentation], batchwise: bool = False, **kwargs)

Randomly choose one augmentation layer from a list of augmentation layers.

Parameters:

  • layers (list[BaseAugmentation]) –

    List of augmentation layers to choose from.

  • batchwise (bool, default: False ) –

    If True, apply same layer to all samples in the batch.

Source code in neuralspot_edge/layers/preprocessing/random_choice.py
def __init__(self, layers: list[BaseAugmentation], batchwise: bool = False, **kwargs):
    """Randomly choose one augmentation layer from a list of augmentation layers.

    Args:
        layers (list[BaseAugmentation]): List of augmentation layers to choose from.
        batchwise (bool): If True, apply same layer to all samples in the batch.
    """
    super().__init__(**kwargs)
    self.layers = layers
    self.batchwise = batchwise

Functions

batch_augment
batch_augment(inputs)

Apply random layer(s) to the batch

Source code in neuralspot_edge/layers/preprocessing/random_choice.py
def batch_augment(self, inputs):
    """Apply random layer(s) to the batch"""
    # If batchwise, apply the same layer to all samples in the batch
    if self.batchwise:
        lyr_idx: int = keras.random.randint(
            shape=(), minval=0, maxval=len(self.layers), dtype="int32", seed=self.random_generator
        )
        branch_fns = [lambda x: layer.call(x, training=self.training) for layer in self.layers]
        return keras.ops.switch(lyr_idx, branch_fns, inputs)
    raise NotImplementedError("Batchwise=False is not implemented yet")
get_config
get_config()

Serializes the configuration of the layer.

Source code in neuralspot_edge/layers/preprocessing/random_choice.py
def get_config(self):
    """Serializes the configuration of the layer."""
    config = super().get_config()
    config.update(
        {
            "layers": self.layers,
            "batchwise": self.batchwise,
        }
    )
    return config