
Model Training

Introduction

Each task provides a mode to train a model on the specified datasets and dataloaders. The training mode can be invoked either via the CLI or from within the heartkit Python package. At a high level, the training mode performs the following actions based on the provided configuration parameters:

  1. Load the configuration parameters (e.g. configuration.json)
  2. Load the desired datasets (e.g. PtbxlDataset)
  3. Load the corresponding task dataloaders (e.g. PtbxlDataLoader)
  4. Initialize the custom model architecture (e.g. tcn)
  5. Define the metrics, loss, and optimizer (e.g. accuracy, categorical_crossentropy, adam)
  6. Train the model (e.g. model.fit)
  7. Save artifacts (e.g. model.keras)

Example configuration (configuration.json):
    {
        "name": "arr-2-eff-sm",
        "project": "hk-rhythm-2",
        "job_dir": "./results/arr-2-eff-sm",
        "verbose": 2,
        "datasets": [
            {
                "name": "ptbxl",
                "params": {
                    "path": "./datasets/ptbxl"
                }
            }
        ],
        "num_classes": 2,
        "class_map": {
            "0": 0,
            "7": 1,
            "8": 1
        },
        "class_names": [
            "NORMAL",
            "AFIB/AFL"
        ],
        "class_weights": "balanced",
        "sampling_rate": 100,
        "frame_size": 512,
        "samples_per_patient": [
            10,
            10
        ],
        "val_samples_per_patient": [
            5,
            5
        ],
        "test_samples_per_patient": [
            5,
            5
        ],
        "val_patients": 0.2,
        "val_size": 20000,
        "test_size": 20000,
        "batch_size": 256,
        "buffer_size": 20000,
        "epochs": 100,
        "steps_per_epoch": 50,
        "val_metric": "loss",
        "lr_rate": 0.001,
        "lr_cycles": 1,
        "threshold": 0.75,
        "val_metric_threshold": 0.98,
        "tflm_var_name": "g_rhythm_model",
        "tflm_file": "rhythm_model_buffer.h",
        "backend": "pc",
        "demo_size": 896,
        "display_report": true,
        "quantization": {
            "qat": false,
            "format": "INT8",
            "io_type": "int8",
            "conversion": "CONCRETE",
            "debug": false
        },
        "preprocesses": [
            {
                "name": "layer_norm",
                "params": {
                    "epsilon": 0.01,
                    "name": "znorm"
                }
            }
        ],
        "augmentations": [],
        "model_file": "model.keras",
        "use_logits": false,
        "architecture": {
            "name": "efficientnetv2",
            "params": {
                "input_filters": 16,
                "input_kernel_size": [
                    1,
                    9
                ],
                "input_strides": [
                    1,
                    2
                ],
                "blocks": [
                    {
                        "filters": 24,
                        "depth": 2,
                        "kernel_size": [
                            1,
                            9
                        ],
                        "strides": [
                            1,
                            2
                        ],
                        "ex_ratio": 1,
                        "se_ratio": 2
                    },
                    {
                        "filters": 32,
                        "depth": 2,
                        "kernel_size": [
                            1,
                            9
                        ],
                        "strides": [
                            1,
                            2
                        ],
                        "ex_ratio": 1,
                        "se_ratio": 2
                    },
                    {
                        "filters": 40,
                        "depth": 2,
                        "kernel_size": [
                            1,
                            9
                        ],
                        "strides": [
                            1,
                            2
                        ],
                        "ex_ratio": 1,
                        "se_ratio": 2
                    },
                    {
                        "filters": 48,
                        "depth": 1,
                        "kernel_size": [
                            1,
                            9
                        ],
                        "strides": [
                            1,
                            2
                        ],
                        "ex_ratio": 1,
                        "se_ratio": 2
                    }
                ],
                "output_filters": 0,
                "include_top": true,
                "use_logits": true
            }
        }
    }
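The num_classes, class_map, class_names, and class_weights fields must agree with one another: class_map folds raw dataset labels ("0", "7", "8") into the class indices 0 and 1, and class_names labels those indices. The sketch below checks that consistency and shows what "balanced" weighting would produce. It assumes that labels missing from class_map are dropped and that "balanced" follows scikit-learn's heuristic n_samples / (n_classes * count_c); both are assumptions, and check_classes, map_labels, and balanced_weights are hypothetical helpers, not HeartKit functions.

```python
import json
from collections import Counter

# Trimmed copy of the class-related fields from configuration.json above.
CONFIG = json.loads("""
{
    "num_classes": 2,
    "class_map": {"0": 0, "7": 1, "8": 1},
    "class_names": ["NORMAL", "AFIB/AFL"]
}
""")


def check_classes(cfg: dict) -> None:
    """Raise if class_map targets or class_names disagree with num_classes."""
    n = cfg["num_classes"]
    targets = set(cfg["class_map"].values())
    if not targets <= set(range(n)):
        raise ValueError(f"class_map targets {sorted(targets)} exceed num_classes={n}")
    if len(cfg["class_names"]) != n:
        raise ValueError(f"expected {n} class_names, got {len(cfg['class_names'])}")


def map_labels(raw_labels, class_map: dict) -> list:
    """Map raw dataset labels to class indices (assumption: unmapped labels are dropped)."""
    return [class_map[str(lbl)] for lbl in raw_labels if str(lbl) in class_map]


def balanced_weights(labels: list, num_classes: int) -> dict:
    """scikit-learn style 'balanced' weights: n_samples / (n_classes * count)."""
    counts = Counter(labels)
    return {c: len(labels) / (num_classes * counts[c]) for c in range(num_classes) if counts[c]}


check_classes(CONFIG)
labels = map_labels([0, 0, 0, 7, 8, 3, 0, 0], CONFIG["class_map"])
print(labels)                                            # [0, 0, 0, 1, 1, 0, 0]
print(balanced_weights(labels, CONFIG["num_classes"]))   # {0: 0.7, 1: 1.75}
```

Under this heuristic the rarer class (here AFIB/AFL) receives the larger weight, so its misclassifications contribute more to the loss.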
    


Training flow (diagram): Load configuration (HKTaskParams) -> Load datasets (DatasetFactory) -> Load dataloaders (DataLoaderFactory) -> Initialize model (ModelFactory) -> Define metrics, loss, optimizer -> Train model -> Save artifacts. Loading the datasets and dataloaders makes up the Preprocess stage; initializing, configuring, and training the model makes up the Model Training stage.
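The flow names DatasetFactory, DataLoaderFactory, and ModelFactory, and the configuration describes datasets, preprocesses, and the architecture as entries with a "name" and a "params" dictionary. A common way to connect the two is a name-to-constructor registry. The sketch below illustrates that pattern only; Factory, NamedParams, and ToyPtbxlDataset here are hypothetical stand-ins, and HeartKit's real factories may be implemented differently (the `.get(name)` lookup merely mirrors the `hk.TaskFactory.get("rhythm")` call used later on this page).

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict


@dataclass
class NamedParams:
    """Stand-in for a {"name": ..., "params": {...}} configuration entry."""
    name: str
    params: Dict[str, Any] = field(default_factory=dict)


class Factory:
    """Minimal name -> constructor registry (hypothetical, not HeartKit's code)."""

    def __init__(self) -> None:
        self._registry: Dict[str, Callable[..., Any]] = {}

    def register(self, name: str, ctor: Callable[..., Any]) -> None:
        self._registry[name] = ctor

    def get(self, name: str) -> Callable[..., Any]:
        if name not in self._registry:
            raise KeyError(f"unknown name {name!r}; available: {sorted(self._registry)}")
        return self._registry[name]

    def build(self, spec: NamedParams) -> Any:
        # Look up the constructor by name and pass the params as keyword arguments.
        return self.get(spec.name)(**spec.params)


@dataclass
class ToyPtbxlDataset:
    """Hypothetical placeholder for a real dataset class."""
    path: str


datasets = Factory()
datasets.register("ptbxl", ToyPtbxlDataset)

ds = datasets.build(NamedParams(name="ptbxl", params={"path": "./datasets/ptbxl"}))
print(ds)  # ToyPtbxlDataset(path='./datasets/ptbxl')
```

Keeping construction behind a registry means a new dataset or architecture only has to be registered under a name before it can be selected from the configuration file.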

Usage

CLI

The following command will train a rhythm model using the reference configuration.

heartkit --task rhythm --mode train --config ./configuration.json
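When configurations are generated or varied programmatically, it can help to write configuration.json from a dictionary and assemble the same command as an argument list. The sketch below uses only the --task, --mode, and --config flags shown above; write_config and train_command are hypothetical helpers, and the trimmed dictionary is for illustration only (a real run will likely need more of the parameters from the example configuration).

```python
import json
import shlex
from pathlib import Path


def write_config(cfg: dict, path: str) -> Path:
    """Serialize a configuration dict to JSON, creating parent folders if needed."""
    p = Path(path)
    p.parent.mkdir(parents=True, exist_ok=True)
    p.write_text(json.dumps(cfg, indent=4))
    return p


def train_command(config_path: Path, task: str = "rhythm") -> list:
    """Build the argv for the training command (nothing is executed here)."""
    return ["heartkit", "--task", task, "--mode", "train", "--config", str(config_path)]


# Trimmed, illustrative configuration.
cfg = {"name": "arr-2-eff-sm", "job_dir": "./results/arr-2-eff-sm", "epochs": 100}
cmd = train_command(write_config(cfg, "./configuration.json"))
print(shlex.join(cmd))
# To launch training (requires heartkit to be installed):
# import subprocess; subprocess.run(cmd, check=True)
```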

Python

The model can be trained using the following snippet:

import heartkit as hk

task = hk.TaskFactory.get("rhythm")

params = hk.HKTaskParams(...)  # replace ... with arguments such as those in the example configuration below

task.train(params)
Example configuration:
    hk.HKTaskParams(
        name="arr-2-eff-sm",
        project="hk-rhythm-2",
        job_dir="./results/arr-2-eff-sm",
        verbose=2,
        datasets=[hk.NamedParams(
            name="ptbxl",
            params=dict(
                path="./datasets/ptbxl"
            )
        )],
        num_classes=2,
        class_map={
            "0": 0,
            "7": 1,
            "8": 1
        },
        class_names=[
            "NORMAL", "AFIB/AFL"
        ],
        class_weights="balanced",
        sampling_rate=100,
        frame_size=512,
        samples_per_patient=[10, 10],
        val_samples_per_patient=[5, 5],
        test_samples_per_patient=[5, 5],
        val_patients=0.20,
        val_size=20000,
        test_size=20000,
        batch_size=256,
        buffer_size=20000,
        epochs=100,
        steps_per_epoch=50,
        val_metric="loss",
        lr_rate=1e-3,
        lr_cycles=1,
        threshold=0.75,
        val_metric_threshold=0.98,
        tflm_var_name="g_rhythm_model",
        tflm_file="rhythm_model_buffer.h",
        backend="pc",
        demo_size=896,
        display_report=True,
        quantization=hk.QuantizationParams(
            qat=False,
            format="INT8",
            io_type="int8",
            conversion="CONCRETE",
            debug=False
        ),
        preprocesses=[
            hk.NamedParams(
                name="layer_norm",
                params=dict(
                    epsilon=0.01,
                    name="znorm"
                )
            )
        ],
        augmentations=[
        ],
        model_file="model.keras",
        use_logits=False,
        architecture=hk.NamedParams(
            name="efficientnetv2",
            params=dict(
                input_filters=16,
                input_kernel_size=[1, 9],
                input_strides=[1, 2],
                blocks=[
                    {"filters": 24, "depth": 2, "kernel_size": [1, 9], "strides": [1, 2], "ex_ratio": 1, "se_ratio": 2},
                    {"filters": 32, "depth": 2, "kernel_size": [1, 9], "strides": [1, 2], "ex_ratio": 1, "se_ratio": 2},
                    {"filters": 40, "depth": 2, "kernel_size": [1, 9], "strides": [1, 2], "ex_ratio": 1, "se_ratio": 2},
                    {"filters": 48, "depth": 1, "kernel_size": [1, 9], "strides": [1, 2], "ex_ratio": 1, "se_ratio": 2}
                ],
                output_filters=0,
                include_top=True,
                use_logits=True
            )
        )
    )
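The example applies a single preprocess, layer_norm with epsilon=0.01, to each frame before it reaches the model. The sketch below uses the textbook layer-normalization formula (x - mean) / sqrt(var + epsilon) over one frame; that HeartKit's layer_norm uses exactly this form is an assumption (some implementations add epsilon to the standard deviation instead), and layer_norm here is a hypothetical pure-Python stand-in.

```python
import math


def layer_norm(frame: list, epsilon: float = 0.01) -> list:
    """Normalize one frame to zero mean and (roughly) unit variance."""
    n = len(frame)
    mu = sum(frame) / n
    var = sum((v - mu) ** 2 for v in frame) / n   # population variance
    scale = 1.0 / math.sqrt(var + epsilon)         # epsilon guards against flat frames
    return [(v - mu) * scale for v in frame]


# Synthetic 512-sample frame (frame_size=512): four sine periods with an offset.
frame = [2.0 * math.sin(2 * math.pi * 4 * i / 512) + 0.5 for i in range(512)]
y = layer_norm(frame, epsilon=0.01)
mean_y = sum(y) / len(y)
std_y = math.sqrt(sum(v * v for v in y) / len(y) - mean_y ** 2)
print(f"mean={mean_y:.3f} std={std_y:.4f}")
```

Because epsilon is added to the variance, the output standard deviation is slightly below 1: this signal has variance 2, so the result is sqrt(2 / 2.01). The offset of 0.5 is removed entirely.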
    

Arguments

Please refer to HKTaskParams for the list of arguments that can be used with the train command.
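Several arguments combine into quantities worth checking before a long run: the frame duration in seconds, how many samples the model sees per epoch and in total, and how far the network downsamples the time axis. The sketch below works these out for the example configuration. It assumes steps_per_epoch counts batches per epoch (as in Keras model.fit), that the stem and each EfficientNetV2 block apply their time-axis stride of 2 once, and "same" padding; training_budget is a hypothetical helper.

```python
import math


def training_budget(sampling_rate: int, frame_size: int, batch_size: int,
                    steps_per_epoch: int, epochs: int, time_strides: list) -> dict:
    """Derive a few quantities implied by the training arguments."""
    out_len = frame_size
    for s in time_strides:
        out_len = math.ceil(out_len / s)   # 'same' padding: output = ceil(input / stride)
    return {
        "window_seconds": frame_size / sampling_rate,
        "samples_per_epoch": batch_size * steps_per_epoch,
        "samples_total": batch_size * steps_per_epoch * epochs,
        "output_length": out_len,
    }


# Values from the example configuration; time strides = input stem + four blocks.
print(training_budget(sampling_rate=100, frame_size=512, batch_size=256,
                      steps_per_epoch=50, epochs=100, time_strides=[2, 2, 2, 2, 2]))
# {'window_seconds': 5.12, 'samples_per_epoch': 12800, 'samples_total': 1280000, 'output_length': 16}
```

So each 512-sample frame covers 5.12 s of ECG at 100 Hz and, under these assumptions, is reduced to 16 time steps before the classification head.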


Logging

HeartKit provides built-in support for logging to several third-party services including Weights & Biases (WANDB) and TensorBoard.

WANDB

The training mode can log all metrics and artifacts (i.e., models) to Weights & Biases (WANDB). To enable WANDB logging, set the environment variable WANDB=1. Before running experiments, sign in by running wandb login.

TensorBoard

The training mode can log all metrics to TensorBoard. To enable TensorBoard logging, set the environment variable TENSORBOARD=1.
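When training is launched from Python instead of a shell, the same WANDB and TENSORBOARD variables can be set in-process. The sketch below does this with os.environ; enable_logging is a hypothetical helper, and setting the variables before importing heartkit or calling task.train is a cautious assumption about when they are read. WANDB logging still requires a prior wandb login.

```python
import os


def enable_logging(wandb: bool = False, tensorboard: bool = False) -> dict:
    """Set the opt-in environment variables for third-party logging."""
    if wandb:
        os.environ["WANDB"] = "1"
    if tensorboard:
        os.environ["TENSORBOARD"] = "1"
    return {key: os.environ.get(key) for key in ("WANDB", "TENSORBOARD")}


print(enable_logging(wandb=True, tensorboard=True))  # {'WANDB': '1', 'TENSORBOARD': '1'}
```

Child processes inherit os.environ, so the variables also apply if the CLI is started afterwards with subprocess.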