distiller
Distiller Trainer API
This module contains the implementation of a distiller trainer that can be used to train a student model via knowledge distillation from a teacher model.
Classes:

- Distiller – A trainer for distillation
Classes
Distiller
Source code in neuralspot_edge/trainers/distiller.py
Functions
compile
compile(optimizer: keras.optimizers.Optimizer, metrics: list[keras.metrics.Metric], student_loss_fn: keras.losses.Loss, distillation_loss_fn: keras.losses.Loss, alpha: float = 0.1, temperature: float = 3)
Configure the distiller.
Parameters:
- optimizer (Optimizer) – Keras optimizer for the student weights
- metrics (list[Metric]) – Keras metrics for evaluation
- student_loss_fn (Loss) – Loss function of difference between student predictions and ground truth
- distillation_loss_fn (Loss) – Loss function of difference between soft student predictions and soft teacher predictions
- alpha (float, default: 0.1) – Weight applied to student_loss_fn; (1 - alpha) is applied to distillation_loss_fn. Defaults to 0.1.
- temperature (float, default: 3) – Temperature for softening probability distributions. Defaults to 3.