Model Training
Introduction
Each task provides a mode to train a model on the specified datasets and dataloaders. The training mode can be invoked either via the CLI or from within the heartkit Python package. At a high level, the training mode performs the following actions based on the provided configuration parameters:
- Load the configuration parameters (e.g. configuration.json)
- Load the desired datasets (e.g. PtbxlDataset)
- Load the corresponding task dataloaders (e.g. PtbxlDataLoader)
- Initialize custom model architecture (e.g. tcn)
- Define the metrics, loss, and optimizer (e.g. accuracy, categorical_crossentropy, adam)
- Train the model (e.g. model.fit)
- Save artifacts (e.g. model.keras)
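The steps above run strictly in sequence, each consuming the output of the previous one. The following framework-free sketch mirrors only that control flow. Every function in it (load_config, load_datasets, load_dataloaders, init_model, define_training, train, save_artifacts, run_training) is a hypothetical stub, not part of the heartkit API, and the "model" is a plain dict so the example runs without TensorFlow.

```python
import json
import tempfile
from pathlib import Path

# Hypothetical stubs: each mimics one step from the list above without real data or TensorFlow.


def load_config(path: Path) -> dict:
    return json.loads(path.read_text())  # Step 1: configuration parameters


def load_datasets(config: dict) -> list:
    return [ds["name"] for ds in config["datasets"]]  # Step 2: dataset names only


def load_dataloaders(datasets: list) -> list:
    return [f"{name}-dataloader" for name in datasets]  # Step 3: one loader per dataset


def init_model(config: dict) -> dict:
    return {"architecture": config["architecture"]["name"]}  # Step 4: no real layers


def define_training() -> dict:
    # Step 5: same names as the examples in the list above
    return {"metrics": ["accuracy"], "loss": "categorical_crossentropy", "optimizer": "adam"}


def train(model: dict, loaders: list, setup: dict, epochs: int) -> dict:
    # Step 6: pretend to fit; only records what would have been used
    return {**model, **setup, "loaders": loaders, "epochs": epochs, "trained": True}


def save_artifacts(model: dict, job_dir: Path, model_file: str) -> Path:
    # Step 7: the stub writes JSON, not a real Keras model file
    job_dir.mkdir(parents=True, exist_ok=True)
    out = job_dir / model_file
    out.write_text(json.dumps(model))
    return out


def run_training(config_path: Path) -> Path:
    config = load_config(config_path)
    loaders = load_dataloaders(load_datasets(config))
    model = init_model(config)
    trained = train(model, loaders, define_training(), config["epochs"])
    return save_artifacts(trained, Path(config["job_dir"]), config["model_file"])


# Usage: a trimmed configuration written to a temporary directory.
tmp = Path(tempfile.mkdtemp())
cfg = {
    "job_dir": str(tmp / "results"),
    "datasets": [{"name": "ptbxl", "params": {"path": "./datasets/ptbxl"}}],
    "epochs": 100,
    "model_file": "model.keras",
    "architecture": {"name": "efficientnetv2"},
}
(tmp / "configuration.json").write_text(json.dumps(cfg))
saved = run_training(tmp / "configuration.json")
print(saved.name)  # model.keras
```

In the real training mode each stub would be replaced by the corresponding factory lookup and Keras call. The point here is only that every step is driven by the same configuration object.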
Example configuration (configuration.json):
{
  "name": "arr-2-eff-sm",
  "project": "hk-rhythm-2",
  "job_dir": "./results/arr-2-eff-sm",
  "verbose": 2,
  "datasets": [
    {
      "name": "ptbxl",
      "params": {
        "path": "./datasets/ptbxl"
      }
    }
  ],
  "num_classes": 2,
  "class_map": {
    "0": 0,
    "7": 1,
    "8": 1
  },
  "class_names": ["NORMAL", "AFIB/AFL"],
  "class_weights": "balanced",
  "sampling_rate": 100,
  "frame_size": 512,
  "samples_per_patient": [10, 10],
  "val_samples_per_patient": [5, 5],
  "test_samples_per_patient": [5, 5],
  "val_patients": 0.2,
  "val_size": 20000,
  "test_size": 20000,
  "batch_size": 256,
  "buffer_size": 20000,
  "epochs": 100,
  "steps_per_epoch": 50,
  "val_metric": "loss",
  "lr_rate": 0.001,
  "lr_cycles": 1,
  "threshold": 0.75,
  "val_metric_threshold": 0.98,
  "tflm_var_name": "g_rhythm_model",
  "tflm_file": "rhythm_model_buffer.h",
  "backend": "pc",
  "demo_size": 896,
  "display_report": true,
  "quantization": {
    "qat": false,
    "format": "INT8",
    "io_type": "int8",
    "conversion": "CONCRETE",
    "debug": false
  },
  "preprocesses": [
    {
      "name": "layer_norm",
      "params": {
        "epsilon": 0.01,
        "name": "znorm"
      }
    }
  ],
  "augmentations": [],
  "model_file": "model.keras",
  "use_logits": false,
  "architecture": {
    "name": "efficientnetv2",
    "params": {
      "input_filters": 16,
      "input_kernel_size": [1, 9],
      "input_strides": [1, 2],
      "blocks": [
        {"filters": 24, "depth": 2, "kernel_size": [1, 9], "strides": [1, 2], "ex_ratio": 1, "se_ratio": 2},
        {"filters": 32, "depth": 2, "kernel_size": [1, 9], "strides": [1, 2], "ex_ratio": 1, "se_ratio": 2},
        {"filters": 40, "depth": 2, "kernel_size": [1, 9], "strides": [1, 2], "ex_ratio": 1, "se_ratio": 2},
        {"filters": 48, "depth": 1, "kernel_size": [1, 9], "strides": [1, 2], "ex_ratio": 1, "se_ratio": 2}
      ],
      "output_filters": 0,
      "include_top": true,
      "use_logits": true
    }
  }
}
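Several fields in the example configuration interact. The sketch below uses only values copied from it to show three of those interactions. First, it shows how class_map folds source labels into the two target classes. This sketch drops labels that class_map does not list, which is an assumption about the behaviour. Second, it shows one common reading of "class_weights": "balanced": the scikit-learn rule n_samples / (n_classes * count_c). Whether HeartKit uses exactly this formula is also an assumption. Third, it shows the window length and sample counts implied by frame_size, sampling_rate, batch_size, steps_per_epoch, and epochs. The source label list is made up for illustration.

```python
from collections import Counter

# Values copied from the example configuration above.
class_map = {"0": 0, "7": 1, "8": 1}
class_names = ["NORMAL", "AFIB/AFL"]
sampling_rate = 100  # Hz
frame_size = 512     # samples per input window
batch_size = 256
steps_per_epoch = 50
epochs = 100


def remap_labels(labels, class_map):
    """Map source labels to target classes; labels absent from class_map are dropped (assumed)."""
    return [class_map[str(lbl)] for lbl in labels if str(lbl) in class_map]


def balanced_weights(labels, num_classes):
    """n_samples / (num_classes * count_c): the scikit-learn 'balanced' rule, assumed here."""
    counts = Counter(labels)
    total = len(labels)
    return {c: total / (num_classes * counts[c]) for c in range(num_classes) if counts[c]}


# Hypothetical source labels: 0 -> NORMAL, 7/8 -> AFIB/AFL, 3 is not in class_map.
source = [0, 0, 0, 0, 0, 0, 7, 8, 3, 3]
targets = remap_labels(source, class_map)
print(targets)                                      # [0, 0, 0, 0, 0, 0, 1, 1]
print(balanced_weights(targets, len(class_names)))  # {0: 0.6666666666666666, 1: 2.0}
print(frame_size / sampling_rate)                   # 5.12 (seconds per window)
print(batch_size * steps_per_epoch)                 # 12800 (samples per epoch)
print(steps_per_epoch * epochs)                     # 5000 (optimizer steps in total)
```

With the example values, each 512-sample frame covers 5.12 s of ECG at 100 Hz. Training then sees 12,800 frames per epoch, and the minority AFIB/AFL class is up-weighted in proportion to how rare it is.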
Figure: training pipeline. Load configuration (HKTaskParams) -> Preprocess [Load datasets (DatasetFactory) -> Load dataloaders (DataLoaderFactory)] -> Model Training [Initialize model (ModelFactory) -> Define metrics, loss, optimizer -> Train model] -> Save artifacts
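The DatasetFactory, DataLoaderFactory and ModelFactory stages in the pipeline above suggest that components are looked up by the name given in the configuration (e.g. "ptbxl" under datasets, "efficientnetv2" under architecture). A name-to-constructor registry is a common way to implement such a lookup. The sketch below is a generic version of that pattern only. Factory, PtbxlStub, register, get and names are hypothetical and do not describe the real heartkit factory classes.

```python
from typing import Callable, Dict, Generic, List, TypeVar

T = TypeVar("T")


class Factory(Generic[T]):
    """Minimal name -> constructor registry (hypothetical, not HeartKit's class)."""

    def __init__(self) -> None:
        self._registry: Dict[str, Callable[..., T]] = {}

    def register(self, name: str, builder: Callable[..., T]) -> None:
        # Refuse silent overwrites so two components cannot share a name.
        if name in self._registry:
            raise ValueError(f"{name!r} is already registered")
        self._registry[name] = builder

    def get(self, name: str) -> Callable[..., T]:
        # Fail with the list of known names to make config typos easy to spot.
        try:
            return self._registry[name]
        except KeyError:
            raise KeyError(f"unknown name {name!r}; known: {self.names()}") from None

    def names(self) -> List[str]:
        return sorted(self._registry)


class PtbxlStub:
    """Stand-in dataset that only remembers its path."""

    def __init__(self, path: str) -> None:
        self.path = path


datasets = Factory[PtbxlStub]()
datasets.register("ptbxl", PtbxlStub)

# Resolve a dataset entry from the configuration by name and pass its params through.
entry = {"name": "ptbxl", "params": {"path": "./datasets/ptbxl"}}
ds = datasets.get(entry["name"])(**entry["params"])
print(type(ds).__name__, ds.path)  # PtbxlStub ./datasets/ptbxl
```

Keeping the "name" plus "params" shape in the configuration means new datasets or architectures can be added by registering a constructor, without changing the training loop itself.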
Usage
CLI
The following command will train a rhythm model using the reference configuration.
Python
The model can be trained using the following snippet:
Arguments
Please refer to HKTaskParams for the list of arguments that can be used with the train command.
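The configuration file is essentially a serialized set of these arguments. As a rough illustration of that mapping, the sketch below loads a few keys from the example configuration into a typed object with default values. TrainParams and from_dict are hypothetical. They cover only a small subset of fields and are not the real HKTaskParams class, which defines many more arguments and its own validation.

```python
import json
from dataclasses import dataclass, field, fields
from pathlib import Path


@dataclass
class TrainParams:
    """Hypothetical subset of training arguments (not the real HKTaskParams)."""

    name: str = "experiment"
    job_dir: Path = Path("./results")
    batch_size: int = 32
    epochs: int = 10
    steps_per_epoch: int = 10
    lr_rate: float = 1e-3
    class_names: list = field(default_factory=list)

    @classmethod
    def from_dict(cls, data: dict) -> "TrainParams":
        known = {f.name for f in fields(cls)}
        # Keys outside this subset are ignored here; a strict loader would reject them.
        params = cls(**{k: v for k, v in data.items() if k in known})
        params.job_dir = Path(params.job_dir)  # JSON stores paths as strings
        if params.batch_size <= 0 or params.epochs <= 0:
            raise ValueError("batch_size and epochs must be positive")
        return params


# A few keys from the example configuration, including one (frame_size) outside the subset.
config = json.loads(
    '{"name": "arr-2-eff-sm", "job_dir": "./results/arr-2-eff-sm", "batch_size": 256,'
    ' "epochs": 100, "steps_per_epoch": 50, "lr_rate": 0.001, "frame_size": 512}'
)
params = TrainParams.from_dict(config)
print(params.batch_size, params.job_dir.name)  # 256 arr-2-eff-sm
```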
Logging
HeartKit provides built-in support for logging to several third-party services, including Weights & Biases (WANDB) and TensorBoard.
WANDB
The training mode can log all metrics and artifacts (i.e., the trained models) to Weights & Biases (WANDB). To enable WANDB logging, set the environment variable WANDB=1. Remember to sign in with wandb login before running experiments.
TensorBoard
The training mode can log all metrics to TensorBoard. To enable TensorBoard logging, set the environment variable TENSORBOARD=1.
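Both integrations are switched on purely through environment variables. When training is launched from Python rather than a shell, the variables can be set on os.environ, presumably before the training run begins. The text only documents the value 1, so how other values (such as "true") are interpreted is not specified. The env_flag helper below is hypothetical and deliberately checks for exactly "1".

```python
import os


def env_flag(name: str) -> bool:
    """Hypothetical helper: True only when the variable is set to "1", the value the docs use."""
    return os.environ.get(name, "0") == "1"


# Equivalent to `export WANDB=1 TENSORBOARD=1` in a shell, but scoped to this process
# and any child processes it starts.
os.environ["WANDB"] = "1"
os.environ["TENSORBOARD"] = "1"

print(env_flag("WANDB"), env_flag("TENSORBOARD"))  # True True
```

For WANDB, the credentials from wandb login are still required; the flag only enables logging.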