Train ECG Segmentation Model¶
Date created: 2024/07/17
Last Modified: 2024/07/17
Description: Train, evaluate, and export 4-stage ECG segmentation model from scratch
Overview¶
In this guide, we will train a model to segment an ECG signal into four classes: NONE
, PWAVE
, QRS
, and TWAVE
. We will use both synthetic and real ECG datasets to train a TCN style model. We will also showcase evaluating and exporting the model for inference via TF Lite and TFLM.
Input
- Sensor: ECG
- Location: Wrist
- Sampling Rate: 100 Hz
- Frame Size: 2.56 seconds
Class Mapping
Segment ECG signal into one of the following classes:
Base Class | Target Class | Label |
---|---|---|
0-NONE | 0 | NONE |
1-PWAVE | 1 | PWAVE |
2-QRS | 2 | QRS |
3-TWAVE | 3 | TWAVE |
Datasets
- Synthetic: Synthetic ECG signals from PhysioKit
- LUDB: Lobachevsky University Electrocardiography database consists of 200 10-second 12-lead records. The boundaries and peaks of P, T waves and QRS complexes were manually annotated by cardiologists. Each record is annotated with the corresponding diagnosis.
#!pip install -q --disable-pip-version-check heartkit
Setup¶
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import IPython
import contextlib
from pathlib import Path
import tempfile
import keras
import heartkit as hk
import physiokit as pk
import numpy as np
import neuralspot_edge as nse
import matplotlib.pyplot as plt
import plotly.io as pio
# Be sure to set the dataset path to the correct location
datasets_dir = Path(os.getenv('HK_DATASET_PATH', './datasets'))
plot_theme = hk.utils.dark_theme
nse.utils.silence_tensorflow()
hk.utils.setup_plotting(plot_theme)
logger = nse.utils.setup_logger(__name__)
Target datasets¶
The only real-world public dataset containing ECG signals with annotated segments is the LUDB dataset. We will use this dataset to train our model. In addition, we will leverage the synthetic dataset provided by PhysioKit to increase amount of data to train on. We will apply several augmentation techniques to the synthetic dataset to increase the diversity of the data.
datasets = [
hk.NamedParams(
name="ecg-synthetic",
params=dict(
num_pts=5000,
params=dict(
presets=["SR", "AFIB", "ant_STEMI", "LAHB", "LPHB", "high_take_off", "LBBB", "random_morphology"],
preset_weights=[8, 4, 1, 1, 1, 1, 1, 0],
duration=10,
sample_rate=100,
heart_rate=[40, 160],
impedance=[1, 2],
p_multiplier=[0.7, 1.3],
t_multiplier=[0.7, 1.3],
noise_multiplier=[0, 0.01],
voltage_factor=[800, 1000]
)
)
),
hk.NamedParams(
name="ludb",
params=dict(
path=datasets_dir / "ludb",
)
)
]
Target classes¶
For this task, we are going to delineate ECG signals into one of four classes:
- None: Background signal
- P-Wave: Atrial depolarization
- QRS: Ventricular depolarization
- T-Wave: Ventricular repolarization
HeartKit already provides a number of heart segments. We will provide a class mapping for the four classes we are interested in. We will also provide class names for display purposes.
class_map = {
hk.tasks.HKSegment.normal: 0,
hk.tasks.HKSegment.pwave: 1,
hk.tasks.HKSegment.qrs: 2,
hk.tasks.HKSegment.twave: 3,
hk.tasks.HKSegment.uwave: 0,
hk.tasks.HKSegment.noise: 0
}
class_names=[
"NONE",
"P-WAVE",
"QRS",
"T-WAVE"
]
Define TCN model architecture¶
For this task, we are going to leverage a customized TCN model architecture that is smaller and can handle 1D signals. The model consists of 4 TCN blocks with a depth of 1. Each block leverages dilated depthwise-separable convolutions along with inverted expansion and squeeze and excitation layers. The model is followed by a 1D convolutional layer and a final dense layer for regression. Unlike vision tasks, we leverage larger kernel sizes and strides to capture temporal dependencies in the ECG signal.
architecture = hk.NamedParams(
name="tcn",
params=dict(
input_kernel=(1, 7),
input_norm="batch",
blocks=[
dict(depth=1, branch=1, filters=16, kernel=(1, 7), dilation=(1, 1), dropout=0.1, ex_ratio=1, se_ratio=0, norm="batch"),
dict(depth=1, branch=1, filters=24, kernel=(1, 7), dilation=(1, 2), dropout=0.1, ex_ratio=1, se_ratio=2, norm="batch"),
dict(depth=1, branch=1, filters=32, kernel=(1, 7), dilation=(1, 4), dropout=0.1, ex_ratio=1, se_ratio=2, norm="batch"),
dict(depth=1, branch=1, filters=48, kernel=(1, 7), dilation=(1, 8), dropout=0.1, ex_ratio=1, se_ratio=2, norm="batch")
],
output_kernel=(1, 7),
include_top=True,
use_logits=True,
model_name="tcn"
)
)
Preprocess pipeline¶
We will preprocess the ECG signals by applying the following steps:
- Apply bandpass filter with cutoff frequencies of 1Hz and 30Hz
- Apply Z-score normalization w/ epsilon to avoid division by zero
The task accepts a list of preprocessing functions that will be applied to the input data.
preprocesses = [hk.NamedParams(
name="layer_norm",
params=dict(
epsilon=0.01,
name="znorm"
)
)]
Augmentation pipeline¶
We will apply the following augmentations to the ECG signals:
- Baseline wander: Simulate baseline wander by adding a random frequency sinusoidal signal to the ECG signal
- Powerline noise: Simulate powerline noise by adding a 50 Hz sinusoidal signal to the ECG signal
- Burst noise: Simulate burst noise by randomly injecting burst of high frequency noise to the ECG signal
- Noise sources: Apply several noises at given frequencies to the ECG signal
- Lead noise: Simulate lead noise by adding a random frequency sinusoidal signal to the ECG signal
- NSTDB: Add real noise captured from NSTDB dataset to the ECG signal.
augmentations = [hk.NamedParams(
name="random_noise_distortion",
params=dict(
amplitude=[0, 0.5],
frequency=[0.5, 1.5],
name="baseline_wander"
)
), hk.NamedParams(
name="random_sine_wave",
params=dict(
amplitude=[0, 0.05],
frequency=[45, 50],
auto_vectorize=False,
name="powerline_noise"
)
), hk.NamedParams(
name="amplitude_warp",
params=dict(
amplitude=[0.9, 1.1],
frequency=[0.5, 1.5],
name="amplitude_warp"
)
), hk.NamedParams(
name="random_noise",
params=dict(
factor=[0, 0.025],
name="random_noise"
)
), hk.NamedParams(
name="random_background_noise",
params=dict(
amplitude=[0, 0.025],
num_noises=1,
name="nstdb"
)
)]
Task configuration¶
Here we provide the complete configuration for the task. This includes the dataset configuration, preprocessing pipeline, model architecture, and training parameters.
params = hk.HKTaskParams(
# Common arguments
name="hk-ecg-segmentation",
job_dir=Path(tempfile.gettempdir()) / "hk-ecg-segmentation",
# Dataset arguments
datasets=datasets,
# Signal arguments
sampling_rate=100,
frame_size=256,
# Dataloader arguments
samples_per_patient=25,
val_samples_per_patient=10,
test_samples_per_patient=10,
# Preprocessing/Augmentation arguments
preprocesses=preprocesses,
augmentations=augmentations,
# Class arguments
num_classes=len(class_names),
class_map=class_map,
class_names=class_names,
# Split arguments
val_patients=0.1,
val_size=20000,
test_size=20000,
val_file="val.pkl",
test_file="val.pkl",
# Model arguments
model_file="model.keras",
architecture=architecture,
# Training parameters
lr_rate=1e-3,
lr_cycles=1,
batch_size=256,
buffer_size=50000,
epochs=100,
steps_per_epoch=50,
val_metric="loss",
class_weights="balanced",
# Evaluation arguments
threshold=0.5,
val_metric_threshold=0.98,
# Export parameters
tflm_var_name="ecg_rhythm_flatbuffer",
tflm_file="ecg_rhythm_flatbuffer.h",
# Demo params
backend="pc",
demo_size=800,
display_report=False,
# Extra arguments
verbose=1,
seed=42
)
Load segmentation task¶
HeartKit provides a TaskFactory that includes a number ready-to-use tasks. Each task provides methods for training, evaluating, exporting, and demoing. We will grab the segmentation task and configure it for our use case.
task = hk.TaskFactory.get("segmentation")
Download the datasets¶
We will download the synthetic and LUDB datasets using the heartkit
library. If already downloaded, this step will be skipped.
task.download(params=params)
Visualize the data¶
Let's visualize a sample ECG signal from the synthetic dataset. Note this contains no noise or artifacts. Augmentations will be applied later to generate noisy samples for training.
ecg, segs, fids = pk.ecg.synthesize(
signal_length=params.frame_size,
sample_rate=params.sampling_rate,
heart_rate=60,
leads=1,
preset=pk.ecg.EcgPreset.SR,
noise_multiplier=0.0
)
ecg = ecg.squeeze()
segs = segs.squeeze()
pwaves = np.where(segs == hk.tasks.HKSegment.pwave, ecg, np.nan)
qrs = np.where(segs == hk.tasks.HKSegment.qrs, ecg, np.nan)
twaves = np.where(segs == hk.tasks.HKSegment.twave, ecg, np.nan)
ts = np.arange(0, len(ecg)) / params.sampling_rate
fig, ax = plt.subplots(1, 1, figsize=(9, 4))
plt.plot(ts, ecg, color=plot_theme.primary_color, lw=2, label="ECG")
plt.plot(ts, pwaves, color=plot_theme.secondary_color, lw=3, label="P-Wave")
plt.plot(ts, qrs, color=plot_theme.tertiary_color, lw=3, label="QRS")
plt.plot(ts, twaves, color=plot_theme.quaternary_color, lw=3, label="T-Wave")
plt.legend()
# Plot segments
plt.title("Synthetic ECG w/ Segments")
ax.set_xlabel("Time (s)")
ax.set_ylabel("Amplitude")
plt.show()
Visualize the augmentations¶
Taking the existing synthetic ECG signal, let's look at the effects of the augmentations on the signal.
augmenter = hk.datasets.create_augmentation_pipeline(
augmentations=params.augmentations,
sampling_rate=params.sampling_rate,
)
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1723838702.560559 758191 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723838702.580638 758191 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723838702.580743 758191 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723838702.581841 758191 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723838702.581918 758191 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723838702.581964 758191 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723838702.624118 758191 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723838702.624214 758191 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723838702.624283 758191 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
ecg_noise = augmenter(ecg.reshape(-1, 1)).numpy().squeeze()
pwaves = np.where(segs == hk.tasks.HKSegment.pwave, ecg_noise, np.nan)
qrs = np.where(segs == hk.tasks.HKSegment.qrs, ecg_noise, np.nan)
twaves = np.where(segs == hk.tasks.HKSegment.twave, ecg_noise, np.nan)
ts = np.arange(0, len(ecg_noise)) / params.sampling_rate
fig, ax = plt.subplots(1, 1, figsize=(9, 4))
plt.plot(ts, ecg_noise, color=plot_theme.primary_color, lw=2, label="ECG")
plt.plot(ts, pwaves, color=plot_theme.secondary_color, lw=3, label="P-Wave")
plt.plot(ts, qrs, color=plot_theme.tertiary_color, lw=3, label="QRS")
plt.plot(ts, twaves, color=plot_theme.quaternary_color, lw=3, label="T-Wave")
plt.legend()
# Plot segments
plt.title("Synthetic ECG w/ Noise")
ax.set_xlabel("Time (s)")
ax.set_ylabel("Amplitude")
plt.show()
Visualize the model¶
Lets quickly instantiate and visualize the model.
model = nse.models.tcn.tcn_from_object(
x=keras.Input(shape=(params.frame_size, 1), name="inputs"),
params=architecture.params,
num_classes=len(class_names)
)
model.summary(layer_range=('inputs', model.layers[10].name))
Model: "TCN"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩ │ inputs (InputLayer) │ (None, 256, 1) │ 0 │ - │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ reshape (Reshape) │ (None, 1, 256, 1) │ 0 │ inputs[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ ENC.CN │ (None, 1, 256, 1) │ 7 │ reshape[0][0] │ │ (DepthwiseConv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ ENC.BN │ (None, 1, 256, 1) │ 4 │ ENC.CN[0][0] │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ B1.D1.DW.B1.CN │ (None, 1, 256, 1) │ 7 │ ENC.BN[0][0] │ │ (DepthwiseConv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ B1.D1.DW.B1.BN │ (None, 1, 256, 1) │ 4 │ B1.D1.DW.B1.CN[0… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ B1.D1.DW.ACT │ (None, 1, 256, 1) │ 0 │ B1.D1.DW.B1.BN[0… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ B1.D1.PW.B1.CN │ (None, 1, 256, │ 16 │ B1.D1.DW.ACT[0][… │ │ (Conv2D) │ 16) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ B1.D1.PW.B1.BN │ (None, 1, 256, │ 64 │ B1.D1.PW.B1.CN[0… │ │ (BatchNormalizatio… │ 16) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ B1.D1.PW.ACT │ (None, 1, 256, │ 0 │ B1.D1.PW.B1.BN[0… │ │ (Activation) │ 16) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ B1.DROP │ (None, 1, 256, │ 0 │ B1.D1.PW.ACT[0][… │ │ (SpatialDropout2D) │ 16) │ │ │ └─────────────────────┴───────────────────┴────────────┴───────────────────┘
Total params: 7,310 (28.55 KB)
Trainable params: 6,922 (27.04 KB)
Non-trainable params: 388 (1.52 KB)
Train the model¶
Using the task configuration, we will train the model on the synthetic and LUDB datasets. We will also apply augmentations to the synthetic dataset to increase the diversity of the data.
task.train(params)
INFO Creating synthetic dataset cache with 5000 patients ecg_synthetic.py:159
Building ecg-synthetic cache: 100%|██████████| 5000/5000 [00:56<00:00, 87.92it/s]
INFO Validation steps per epoch: 78 datasets.py:107
Training: 0%| 0/100 ETA: ?s, ?epochs/sWARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1723838774.747244 758430 service.cc:146] XLA service 0x7b7988001e40 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices: I0000 00:00:1723838774.747268 758430 service.cc:154] StreamExecutor device (0): NVIDIA GeForce RTX 4090, Compute Capability 8.9 I0000 00:00:1723838783.452909 758430 device_compiler.h:188] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. Training: 100%|██████████ 100/100 ETA: 00:00s, 1.96s/epochs
78/78 ━━━━━━━━━━━━━━━━━━━━ 1s 743us/step 78/78 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - acc: 0.8512 - f1: 0.8524 - loss: 0.1304
Model evaluation¶
Now that we have trained the model, we will evaluate the model on the test dataset. Similar to training, we will provide the high-level configuration to the task process.
task.evaluate(params)
INFO Creating synthetic dataset cache with 5000 patients ecg_synthetic.py:159
Building ecg-synthetic cache: 100%|██████████| 5000/5000 [00:57<00:00, 87.27it/s]
78/78 ━━━━━━━━━━━━━━━━━━━━ 1s 970us/step - acc: 0.8674 - f1: 0.8687 - loss: 0.1110
INFO [TEST SET] ACC=0.8652, F1=0.8665, LOSS=0.1141 evaluate.py:47
78/78 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step
Confusion matrix¶
Let's visualize the confusion matrix to understand the model's performance on each class.
IPython.display.Image(filename=params.job_dir / "confusion_matrix_test.png", width=500)
Export model to TF Lite / TFLM¶
Once we have trained and evaluated the model, we need to export the model into a format that can be used for inference on the edge. Currently, we export the model to TensorFlow Lite flatbuffer format. This will also generate a C header file that can be used with TensorFlow Lite for Microcontrollers (TFLM).
Post-Training Quantization (PTQ)¶
For running on bare metal, we will perform post-training quantization to convert the model to an 8-bit integer model. The weights and activations will be quantized to 8-bits and biases will be quantized to 32-bits. This will reduce the model size and improve the inference speed.
quantization = hk.QuantizationParams(
enabled=True,
format="INT8",
io_type="int8",
conversion="CONCRETE",
)
params.quantization = quantization
# TF dumps a lot of info to stdout, so we redirect it to /dev/null
with open(os.devnull, 'w') as devnull:
with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
task.export(params)
INFO Creating synthetic dataset cache with 5000 patients ecg_synthetic.py:159
I0000 00:00:1723839092.566023 758191 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723839092.566111 758191 devices.cc:67] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 1 I0000 00:00:1723839092.566407 758191 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723839092.566464 758191 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723839092.566510 758191 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723839092.566580 758191 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723839092.566627 758191 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 W0000 00:00:1723839092.671832 758191 tf_tfl_flatbuffer_helpers.cc:392] Ignored output_format. W0000 00:00:1723839092.671846 758191 tf_tfl_flatbuffer_helpers.cc:395] Ignored drop_control_dependency. fully_quantize: 0, inference_type: 6, input_inference_type: INT8, output_inference_type: INT8 INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
Run inference demo¶
We will run a demo on the PC to verify that the model is working as expected. The demo will load the model and run inferences across a randomly selected ECG signal. The demo will also provide the model's prediction and the corresponding class name.
task.demo(params=params)
Inference: 100%|██████████| 4/4 [00:00<00:00, 5.55it/s]