
Distiller Trainer API

This module implements a distillation trainer that trains a student model against both ground-truth labels and the soft predictions of a teacher model.

Classes

Distiller

Distiller(student: keras.models.Model, teacher: keras.models.Model)
Source code in neuralspot_edge/trainers/distiller.py
def __init__(self, student: keras.models.Model, teacher: keras.models.Model):
    super().__init__()
    self.teacher = teacher
    self.student = student

Functions

compile
compile(optimizer: keras.optimizers.Optimizer, metrics: list[keras.metrics.Metric], student_loss_fn: keras.losses.Loss, distillation_loss_fn: keras.losses.Loss, alpha: float = 0.1, temperature: float = 3)

Configure the distiller.

Parameters:

  • optimizer (Optimizer) –

    Keras optimizer for the student weights

  • metrics (list[Metric]) –

    Keras metrics for evaluation

  • student_loss_fn (Loss) –

    Loss function measuring the difference between the student's predictions and the ground-truth labels

  • distillation_loss_fn (Loss) –

    Loss function measuring the difference between the softened student predictions and the softened teacher predictions

  • alpha (float, default: 0.1 ) –

    Weight applied to student_loss_fn; distillation_loss_fn is weighted by 1 - alpha (see the sketch after this list). Defaults to 0.1.

  • temperature (float, default: 3 ) –

    Temperature for softening the probability distributions; higher values yield softer (more uniform) distributions. Defaults to 3.

Source code in neuralspot_edge/trainers/distiller.py
def compile(
    self,
    optimizer: keras.optimizers.Optimizer,
    metrics: list[keras.metrics.Metric],
    student_loss_fn: keras.losses.Loss,
    distillation_loss_fn: keras.losses.Loss,
    alpha: float = 0.1,
    temperature: float = 3,
):
    """Configure the distiller.

    Args:
        optimizer (keras.optimizers.Optimizer): Keras optimizer for the student weights
        metrics (list[keras.metrics.Metric]): Keras metrics for evaluation
        student_loss_fn (keras.losses.Loss): Loss function of difference between student
            predictions and ground-truth
        distillation_loss_fn (keras.losses.Loss): Loss function of difference between soft
            student predictions and soft teacher predictions
        alpha (float, optional): weight to student_loss_fn and 1-alpha to distillation_loss_fn. Defaults to 0.1.
        temperature (float, optional): Temperature for softening probability distributions. Defaults to 3.
    """
    super().compile(optimizer=optimizer, metrics=metrics)
    self.student_loss_fn = student_loss_fn
    self.distillation_loss_fn = distillation_loss_fn
    self.alpha = alpha
    self.temperature = temperature