
Model Training

Introduction

Each task provides a mode to train a model on the specified datasets. The training mode can be invoked either via the CLI or from the heartkit Python package. At a high level, the training mode performs the following actions based on the provided configuration parameters (a Python sketch of this flow follows the list):

  1. Load the configuration data (e.g. rhythm-class-2.json)
  2. Load the desired datasets (e.g. icentia11k)
  3. Load the custom model architecture (e.g. tcn)
  4. Train the model
  5. Save the trained model
  6. Generate training report
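
The configuration data from step 1 can also be loaded and used directly from Python. The following is a minimal sketch, assuming the JSON configuration file's keys mirror the HKTrainParams fields documented under Arguments below:

import json
from pathlib import Path

import heartkit as hk

# Load the reference configuration (step 1) and build the training parameters from it.
# Assumes the JSON keys mirror the HKTrainParams fields.
config = json.loads(Path("./configs/rhythm-class-2.json").read_text())
params = hk.HKTrainParams(**config)

# Load the task and start training; steps 2-6 are handled internally.
task = hk.TaskFactory.get("rhythm")
task.train(params)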

Usage

Example

The following command will train a rhythm model using the reference configuration:

heartkit --task rhythm --mode train --config ./configs/rhythm-class-2.json

Alternatively, the same training can be launched from Python:

import os
from pathlib import Path
import heartkit as hk

task = hk.TaskFactory.get("rhythm")
task.train(hk.HKTrainParams(
    job_dir=Path("./results/rhythm-class-2"),
    ds_path=Path("./datasets"),
    datasets=[{
        "name": "icentia11k",
        "params": {}
    }],
    num_classes=2,
    class_map={
        0: 0,
        1: 1,
        2: 1
    },
    class_names=[
        "NONE", "AFIB/AFL"
    ],
    sampling_rate=200,
    frame_size=800,
    samples_per_patient=[100, 800],
    val_samples_per_patient=[100, 800],
    data_parallelism=os.cpu_count() or 1,
    preprocesses=[
        hk.PreprocessParams(
            name="znorm",
            params=dict(
                eps=0.01,
                axis=None
            )
        )
    ]
))

Arguments

The following tables list the parameters that can be used to configure the training mode.

HKTrainParams

| Argument | Type | Opt/Req | Default | Description |
| --- | --- | --- | --- | --- |
| name | str | Optional | "experiment" | Experiment name |
| job_dir | Path | Optional | tempfile.gettempdir | Job output directory |
| ds_path | Path | Optional | Path() | Dataset directory |
| datasets | list[DatasetParams] | Optional | | Datasets |
| sampling_rate | int | Optional | 250 | Target sampling rate (Hz) |
| frame_size | int | Optional | 1250 | Frame size |
| num_classes | int | Optional | 1 | # of classes |
| class_map | dict[int, int] | Optional | | Class/label mapping |
| class_names | list[str] | Optional | None | Class names |
| samples_per_patient | int\|list[int] | Optional | 1000 | # train samples per patient |
| val_samples_per_patient | int\|list[int] | Optional | 1000 | # validation samples per patient |
| train_patients | float\|None | Optional | None | # or proportion of patients for training |
| val_patients | float\|None | Optional | None | # or proportion of patients for validation |
| val_file | Path\|None | Optional | None | Path to load/store pickled validation file |
| val_size | int\|None | Optional | None | # samples for validation |
| resume | bool | Optional | False | Resume training |
| architecture | ModelArchitecture\|None | Optional | None | Custom model architecture |
| model_file | Path\|None | Optional | None | Path to save model file (.keras) |
| weights_file | Path\|None | Optional | None | Path to checkpoint weights to load |
| quantization | QuantizationParams | Optional | | Quantization parameters |
| lr_rate | float | Optional | 1e-3 | Learning rate |
| lr_cycles | int | Optional | 3 | Number of learning rate cycles |
| lr_decay | float | Optional | 0.9 | Learning rate decay |
| class_weights | Literal["balanced", "fixed"] | Optional | "fixed" | Class weights |
| batch_size | int | Optional | 32 | Batch size |
| buffer_size | int | Optional | 100 | Buffer size |
| epochs | int | Optional | 50 | Number of epochs |
| steps_per_epoch | int | Optional | 10 | Number of steps per epoch |
| val_metric | Literal["loss", "acc", "f1"] | Optional | "loss" | Performance metric |
| preprocesses | list[PreprocessParams] | Optional | [] | Preprocesses |
| augmentations | list[AugmentationParams] | Optional | [] | Augmentations |
| seed | int\|None | Optional | None | Random state seed |
| data_parallelism | int | Optional | os.cpu_count() or 1 | # of data loaders running in parallel |
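
As an illustration of the split- and validation-related arguments above, the following sketch uses made-up values and omits the datasets and model settings a real run would need:

from pathlib import Path

import heartkit as hk

# Illustrative values only; see the table above for the meaning of each field.
params = hk.HKTrainParams(
    job_dir=Path("./results/rhythm-class-2"),
    train_patients=0.8,        # use 80% of patients for training
    val_patients=0.2,          # use 20% of patients for validation
    val_file=Path("./results/rhythm-class-2/val.pkl"),  # cache the pickled validation set
    val_metric="f1",           # select the best checkpoint by F1 score
    class_weights="balanced",  # "balanced" or "fixed"
    seed=42,                   # fix the random state for reproducibility
)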

QuantizationParams

| Argument | Type | Opt/Req | Default | Description |
| --- | --- | --- | --- | --- |
| enabled | bool | Optional | False | Enable quantization |
| qat | bool | Optional | False | Enable quantization aware training (QAT) |
| ptq | bool | Optional | False | Enable post training quantization (PTQ) |
| input_type | str\|None | Optional | None | Input type |
| output_type | str\|None | Optional | None | Output type |
| supported_ops | list[str]\|None | Optional | None | Supported ops |
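
Quantization settings are passed to HKTrainParams via the quantization argument. The sketch below is illustrative and assumes QuantizationParams is exposed at the package root like HKTrainParams and PreprocessParams:

import heartkit as hk

# Hypothetical sketch: request int8 post-training quantization (PTQ).
# Field names come from the table above; the int8 types are illustrative values.
quantization = hk.QuantizationParams(
    enabled=True,
    ptq=True,            # post-training quantization
    qat=False,           # no quantization-aware training
    input_type="int8",
    output_type="int8",
    supported_ops=None,
)

params = hk.HKTrainParams(quantization=quantization)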

DatasetParams

| Argument | Type | Opt/Req | Default | Description |
| --- | --- | --- | --- | --- |
| name | str | Required | | Dataset name |
| params | dict[str, Any] | Optional | {} | Dataset parameters |
| weight | float | Optional | 1 | Dataset weight |
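
Dataset entries are plain name/params/weight mappings, as shown in the datasets argument of the example above. A rough sketch of mixing two datasets follows; the second dataset name is only a placeholder for whichever additional dataset your installation supports:

# Sketch: combine two datasets; weight sets each dataset's relative contribution.
datasets = [
    {"name": "icentia11k", "params": {}, "weight": 0.75},
    {"name": "another_dataset", "params": {}, "weight": 0.25},  # placeholder name
]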

PreprocessParams

| Argument | Type | Opt/Req | Default | Description |
| --- | --- | --- | --- | --- |
| name | str | Required | | Preprocess name |
| params | dict[str, Any] | Optional | {} | Preprocess parameters |

AugmentationParams

| Argument | Type | Opt/Req | Default | Description |
| --- | --- | --- | --- | --- |
| name | str | Required | | Augmentation name |
| params | dict[str, Any] | Optional | {} | Augmentation parameters |
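
Preprocesses and augmentations share the same name/params structure. The sketch below reuses the znorm preprocess from the example above and adds a purely hypothetical augmentation entry, assuming AugmentationParams is exposed at the package root like PreprocessParams:

import heartkit as hk

preprocesses = [
    hk.PreprocessParams(name="znorm", params=dict(eps=0.01, axis=None)),
]
augmentations = [
    # "random_noise" and its params are hypothetical; substitute an augmentation
    # name available in your installation.
    hk.AugmentationParams(name="random_noise", params=dict(amplitude=0.1)),
]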

Logging

HeartKit provides built-in support for logging to several third-party services including Weights & Biases (WANDB) and TensorBoard.

WANDB

The training mode can log all metrics and artifacts (i.e. models) to Weights & Biases (WANDB). To enable WANDB logging, set the environment variable WANDB=1. Remember to sign in before running experiments by running wandb login.

TensorBoard

The training mode can log all metrics to TensorBoard. To enable TensorBoard logging, set the environment variable TENSORBOARD=1.
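
Both integrations are toggled purely by environment variables, so they can also be enabled from Python before training starts:

import os

# Enable the WANDB and TensorBoard logging described above.
# For WANDB, remember to sign in first by running `wandb login`.
os.environ["WANDB"] = "1"
os.environ["TENSORBOARD"] = "1"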