distiller
Distiller Trainer API
This module contains the implementation of a distiller trainer that can be used to train a student model by distilling knowledge from a teacher model.

Classes:

- Distiller – A trainer for knowledge distillation
Classes
Distiller
Source code in neuralspot_edge/trainers/distiller.py
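The snippet below is a minimal sketch of constructing a Distiller. It assumes the constructor accepts `student` and `teacher` Keras models, following the common Keras knowledge-distillation pattern; the actual constructor arguments may differ, so consult the source file above.

```python
import keras

from neuralspot_edge.trainers.distiller import Distiller

# Illustrative teacher/student pair with matching output shapes
teacher = keras.Sequential([
    keras.layers.Input(shape=(28, 28, 1)),
    keras.layers.Flatten(),
    keras.layers.Dense(256, activation="relu"),
    keras.layers.Dense(10),
])
student = keras.Sequential([
    keras.layers.Input(shape=(28, 28, 1)),
    keras.layers.Flatten(),
    keras.layers.Dense(64, activation="relu"),
    keras.layers.Dense(10),
])

# Assumption: the trainer wraps both models; keyword names are illustrative
distiller = Distiller(student=student, teacher=teacher)
```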
Functions
compile
compile(optimizer: keras.optimizers.Optimizer, metrics: list[keras.metrics.Metric], student_loss_fn: keras.losses.Loss, distillation_loss_fn: keras.losses.Loss, alpha: float = 0.1, temperature: float = 3)
Configure the distiller.
Parameters:

- optimizer (Optimizer) – Keras optimizer for the student weights
- metrics (list[Metric]) – Keras metrics for evaluation
- student_loss_fn (Loss) – Loss function of difference between student predictions and ground-truth labels
- distillation_loss_fn (Loss) – Loss function of difference between soft student predictions and soft teacher predictions
- alpha (float, default: 0.1) – Weight applied to student_loss_fn; 1 - alpha is applied to distillation_loss_fn. Defaults to 0.1.
- temperature (float, default: 3) – Temperature for softening probability distributions. Defaults to 3.
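Below is a hedged example of configuring the distiller with the parameters above. The specific loss and metric choices are illustrative, not prescribed by the API; `from_logits=True` assumes both models emit logits rather than probabilities.

```python
distiller.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    # Hard loss: student predictions vs. ground-truth labels
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    # Soft loss: softened student vs. softened teacher distributions
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,      # weights student loss by 0.1 and distillation loss by 0.9
    temperature=3,  # higher temperature -> softer probability distributions
)

# Training then proceeds as with any Keras model (illustrative call)
# distiller.fit(x_train, y_train, epochs=5, validation_data=(x_val, y_val))
```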