{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Train ECG Denoiser\n", "\n", "__Date created:__ 2024/08/13 \n", "\n", "__Last Modified:__ 2024/07/17 \n", "\n", "__Description:__ Train, evaluate, and export an ECG denoiser model from scratch\n", "\n", "\n", "## Overview \n", "\n", "In this guide, we will train an ECG denoiser to remove noise and artifacts from raw ECG signals. \n", "Once the model is trained, we demonstrate how to evaluate it and export it for inference with both TF Lite and TF Lite for Microcontrollers (TFLM).\n", "\n", "__Input__\n", "\n", "- **Sensor**: ECG \n", "- **Location**: Wrist\n", "- **Sampling Rate**: 100 Hz\n", "- **Frame Size**: 2.56 seconds (256 samples)\n", "\n", "__Datasets__\n", "\n", "- **[Synthetic](https://ambiqai.github.io/heartkit/datasets/synthetic/)**: Synthetic ECG signals from PhysioKit\n", "- **[PTB-XL](https://ambiqai.github.io/heartkit/datasets/ptbxl/)**: PTB-XL is a large, publicly available electrocardiography dataset. \n", "It contains 21,837 clinical 12-lead ECGs, each 10 seconds long, from 18,885 patients. 
The ECGs are sampled at 500 Hz and are annotated by up to two cardiologists.\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "!pip install -q --disable-pip-version-check heartkit" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import os\n", "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n", "import contextlib\n", "from pathlib import Path\n", "import tempfile\n", "import keras\n", "import heartkit as hk\n", "import numpy as np\n", "import neuralspot_edge as nse\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Be sure to set the dataset path to the correct location\n", "datasets_dir = Path(os.getenv('HK_DATASET_PATH', './datasets'))\n", "\n", "plot_theme = hk.utils.dark_theme\n", "nse.utils.silence_tensorflow()\n", "hk.utils.setup_plotting(plot_theme)\n", "logger = nse.utils.setup_logger(__name__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create the preprocessing/augmentation pipeline\n", "\n", "Since our goal is to denoise ECG signals, we need an augmentation pipeline that generates noisy versions of the clean input signals. 
\n", "\n", "We will leverage `neuralspot-edge` preprocessing layers to first normalize each frame (`layer_norm`) and then apply the following augmentations:\n", "\n", "* Baseline wander: Simulate baseline wander by adding a low-frequency sine signal\n", "* Powerline noise: Simulate powerline noise by adding a 45-50 Hz sinusoidal signal \n", "* Amplitude warp: Simulate amplitude warp by randomly scaling the signal along a low-frequency sine wave\n", "* Gaussian noise: Simulate lead noise by adding random noise drawn from a Gaussian distribution\n", "* Background noise: Add real noise recordings from the MIT-BIH Noise Stress Test Database (NSTDB)\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "preprocesses = [hk.NamedParams(\n", " name=\"layer_norm\",\n", " params=dict(\n", " epsilon=0.01\n", " )\n", ")]\n", "\n", "augmentations = [hk.NamedParams(\n", " name=\"random_noise_distortion\",\n", " params=dict(\n", " amplitude=[0.1, 1.5],\n", " frequency=[0.5, 1.5],\n", " name=\"baseline_wander\"\n", " )\n", "), hk.NamedParams(\n", " name=\"random_sine_wave\",\n", " params=dict(\n", " amplitude=[0, 0.05],\n", " frequency=[45, 50],\n", " auto_vectorize=False,\n", " name=\"powerline_noise\"\n", " )\n", "), hk.NamedParams(\n", " name=\"amplitude_warp\",\n", " params=dict(\n", " amplitude=[0.9, 1.1],\n", " frequency=[0.5, 1.5],\n", " name=\"amplitude_warp\"\n", " )\n", "), hk.NamedParams(\n", " name=\"random_noise\",\n", " params=dict(\n", " factor=[0.1, 0.5],\n", " name=\"random_noise\"\n", " )\n", "), hk.NamedParams(\n", " name=\"random_background_noise\",\n", " params=dict(\n", " amplitude=[0.1, 0.5],\n", " num_noises=2,\n", " name=\"nstdb\"\n", " )\n", ")]\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define TCN model architecture\n", "\n", "For this task, we use a compact, customized __TCN__ (temporal convolutional network) architecture that operates on 1D signals. The model consists of 5 TCN blocks, each with a depth of 1. 
Each block uses dilated depthwise-separable convolutions along with inverted expansion and squeeze-and-excitation (SE) layers, and the dilation rates (1, 1, 2, 4, 8) widen the receptive field as the network deepens. The blocks are followed by a final 1D convolutional output layer. " ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "mbconv_blocks = [\n", " dict(depth=1, branch=1, filters=16, kernel=(1, 7), dilation=(1, 1), dropout=0, ex_ratio=1, se_ratio=0, norm=\"batch\"),\n", " dict(depth=1, branch=1, filters=24, kernel=(1, 7), dilation=(1, 1), dropout=0, ex_ratio=1, se_ratio=2, norm=\"batch\"),\n", " dict(depth=1, branch=1, filters=32, kernel=(1, 7), dilation=(1, 2), dropout=0, ex_ratio=1, se_ratio=2, norm=\"batch\"),\n", " dict(depth=1, branch=1, filters=40, kernel=(1, 7), dilation=(1, 4), dropout=0, ex_ratio=1, se_ratio=2, norm=\"batch\"),\n", " dict(depth=1, branch=1, filters=48, kernel=(1, 7), dilation=(1, 8), dropout=0, ex_ratio=1, se_ratio=2, norm=\"batch\")\n", "]\n", "\n", "architecture = dict(\n", " name=\"tcn\",\n", " params=dict(\n", " input_kernel=(1, 7),\n", " input_norm=\"batch\",\n", " blocks=mbconv_blocks,\n", " output_kernel=(1, 7),\n", " include_top=True,\n", " use_logits=True,\n", " model_name=\"tcn\"\n", " )\n", ")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Configure datasets\n", "\n", "Capturing noise-free ECG signals is challenging because real recordings contain various artifacts. Therefore, we train on a combination of synthetic data and controlled, real-world datasets. HeartKit exposes a synthetic ECG dataset generator provided by PhysioKit. 
\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "datasets = [\n", " hk.NamedParams(\n", " name=\"ecg-synthetic\",\n", " params=dict(\n", " num_pts=5000,\n", " params=dict(\n", " presets=[\"SR\", \"AFIB\", \"ant_STEMI\", \"LAHB\", \"LPHB\", \"high_take_off\", \"LBBB\", \"random_morphology\"],\n", " preset_weights=[24, 8, 1, 1, 1, 1, 1, 0],\n", " duration=10,\n", " sample_rate=100,\n", " heart_rate=[40, 160],\n", " impedance=[1, 2],\n", " p_multiplier=[0.7, 1.3],\n", " t_multiplier=[0.7, 1.3],\n", " noise_multiplier=[0, 0.01],\n", " voltage_factor=[800, 1000]\n", " )\n", " )\n", " ),\n", " hk.NamedParams(\n", " name=\"ptbxl\",\n", " params=dict(\n", " path=datasets_dir / \"ptbxl\",\n", " )\n", " )\n", "]\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Task configuration\n", "\n", "Here we define the task parameters that we will use throughout the guide. For better results, adjust parameters such as `batch_size`, `epochs`, and `lr_rate` as needed." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "params = hk.HKTaskParams(\n", " # Common arguments\n", " name=\"hk-ecg-denoiser\",\n", " job_dir=Path(tempfile.gettempdir()) / \"hk-ecg-denoiser\",\n", " # Dataset arguments\n", " datasets=datasets,\n", " # Signal arguments\n", " sampling_rate=100,\n", " frame_size=256,\n", " # Dataloader arguments\n", " samples_per_patient=5,\n", " val_samples_per_patient=10,\n", " test_samples_per_patient=10,\n", " # Preprocessing/Augmentation arguments\n", " preprocesses=preprocesses,\n", " augmentations=augmentations,\n", " # Class arguments\n", " num_classes=1,\n", " class_map={0: 0},\n", " class_names=[\"DENOISE\"],\n", " # Split arguments\n", " val_patients=0.1,\n", " val_size=10000,\n", " test_size=10000,\n", " val_file=\"val.pkl\",\n", " test_file=\"val.pkl\",\n", " # Model arguments\n", " model_file=\"model.keras\",\n", " architecture=architecture,\n", " # Training parameters\n", " lr_rate=1e-3,\n", " lr_cycles=1,\n", " batch_size=256,\n", " buffer_size=25000,\n", " epochs=100,\n", " steps_per_epoch=50,\n", " val_metric=\"loss\",\n", " class_weights=\"balanced\",\n", " # Evaluation arguments\n", " threshold=0.5,\n", " val_metric_threshold=0.98,\n", " # Export parameters\n", " tflm_var_name=\"ecg_denoise_flatbuffer\",\n", " tflm_file=\"ecg_denoise_flatbuffer.h\",\n", " # Demo params\n", " backend=\"pc\",\n", " demo_size=800,\n", " display_report=True,\n", " # Extra arguments\n", " verbose=1,\n", " seed=42\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load the denoise task \n", "\n", "HeartKit provides a __TaskFactory__ that includes a number of ready-to-use tasks. Each task provides methods for training, evaluating, exporting, and running demos. We will grab the __denoise__ task and run each step with the parameters defined above."
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "task = hk.TaskFactory.get(\"denoise\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Download the datasets\n", "\n", "We will download the datasets using `heartkit`. PTB-XL is downloaded to `datasets_dir`, while the synthetic dataset is generated locally rather than downloaded. If a dataset has already been downloaded, this step is skipped." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "task.download(params=params)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualize the data\n", "\n", "Let's visualize a sample ECG signal from the synthetic dataset. Note that this signal contains little to no noise or artifacts (the generator's `noise_multiplier` is at most 0.01). Augmentations will be applied later to generate noisy samples for training." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "ds = hk.DatasetFactory.get(params.datasets[0].name)(\n", " cacheable=False,\n", " **params.datasets[0].params\n", ")\n", "\n", "ds_gen = ds.signal_generator(\n", " patient_generator=nse.utils.uniform_id_generator(ds.get_test_patient_ids()),\n", " frame_size=params.frame_size,\n", " samples_per_patient=params.samples_per_patient,\n", " target_rate=params.sampling_rate,\n", ")\n", "ecg = next(ds_gen)\n", "\n", "ts = np.arange(0, len(ecg)) / params.sampling_rate\n", "fig, ax = plt.subplots(1, 1, figsize=(9, 4))\n", "ax.plot(ts, ecg, color=plot_theme.primary_color, lw=3)\n", "fig.suptitle(\"Raw ECG Signal\")\n", "ax.set_xlabel(\"Time (s)\")\n", "ax.set_ylabel(\"Amplitude\")\n", "fig.tight_layout()\n", "fig.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Visualize augmented data\n", "\n", "Let's visualize the augmented data to understand how the augmentations affect the ECG signals." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "preprocessor = hk.datasets.create_augmentation_pipeline(\n", " augmentations=params.preprocesses,\n", " sampling_rate=params.sampling_rate,\n", ")\n", "\n", "augmenter = hk.datasets.create_augmentation_pipeline(\n", " augmentations=params.augmentations,\n", " sampling_rate=params.sampling_rate,\n", ")" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "aug_ecg = augmenter(preprocessor(keras.ops.convert_to_tensor(np.reshape(ecg, (1, -1, 1)))), training=True)\n", "aug_ecg = keras.ops.convert_to_numpy(aug_ecg).squeeze()\n", "\n", "ts = np.arange(0, len(aug_ecg)) / params.sampling_rate\n", "fig, ax = plt.subplots(1, 1, figsize=(9, 4))\n", "ax.plot(ts, aug_ecg, color=plot_theme.primary_color, lw=3)\n", "fig.suptitle(\"Augmented ECG Signal\")\n", "ax.set_xlabel(\"Time (s)\")\n", "ax.set_ylabel(\"Amplitude\")\n", "fig.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualize the model\n", "\n", "Let's view the first 11 layers of the model (up to `model.layers[10]`) to better understand the architecture." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Model: \"TCN\"\n",
       "
\n" ], "text/plain": [ "\u001b[1mModel: \"TCN\"\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
       "┃ Layer (type)        ┃ Output Shape      ┃    Param # ┃ Connected to      ┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
       "│ inputs (InputLayer) │ (None, 256, 1)    │          0 │ -                 │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ reshape (Reshape)   │ (None, 1, 256, 1) │          0 │ inputs[0][0]      │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ ENC.CN              │ (None, 1, 256, 1) │          7 │ reshape[0][0]     │\n",
       "│ (DepthwiseConv2D)   │                   │            │                   │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ ENC.BN              │ (None, 1, 256, 1) │          4 │ ENC.CN[0][0]      │\n",
       "│ (BatchNormalizatio… │                   │            │                   │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ B1.D1.DW.B1.CN      │ (None, 1, 256, 1) │          7 │ ENC.BN[0][0]      │\n",
       "│ (DepthwiseConv2D)   │                   │            │                   │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ B1.D1.DW.B1.BN      │ (None, 1, 256, 1) │          4 │ B1.D1.DW.B1.CN[0… │\n",
       "│ (BatchNormalizatio… │                   │            │                   │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ B1.D1.DW.ACT        │ (None, 1, 256, 1) │          0 │ B1.D1.DW.B1.BN[0… │\n",
       "│ (Activation)        │                   │            │                   │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ B1.D1.PW.B1.CN      │ (None, 1, 256,    │         16 │ B1.D1.DW.ACT[0][… │\n",
       "│ (Conv2D)            │ 16)               │            │                   │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ B1.D1.PW.B1.BN      │ (None, 1, 256,    │         64 │ B1.D1.PW.B1.CN[0… │\n",
       "│ (BatchNormalizatio… │ 16)               │            │                   │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ B1.D1.PW.ACT        │ (None, 1, 256,    │          0 │ B1.D1.PW.B1.BN[0… │\n",
       "│ (Activation)        │ 16)               │            │                   │\n",
       "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
       "│ B2.D1.DW.B1.CN      │ (None, 1, 256,    │        112 │ B1.D1.PW.ACT[0][… │\n",
       "│ (DepthwiseConv2D)   │ 16)               │            │                   │\n",
       "└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n",
       "
\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n", "│ inputs (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ reshape (\u001b[38;5;33mReshape\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m, \u001b[38;5;34m256\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ inputs[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ ENC.CN │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m, \u001b[38;5;34m256\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m7\u001b[0m │ reshape[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ (\u001b[38;5;33mDepthwiseConv2D\u001b[0m) │ │ │ │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ ENC.BN │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m, \u001b[38;5;34m256\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m4\u001b[0m │ ENC.CN[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ (\u001b[38;5;33mBatchNormalizatio…\u001b[0m │ │ │ │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ B1.D1.DW.B1.CN │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m, \u001b[38;5;34m256\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m7\u001b[0m │ 
ENC.BN[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", "│ (\u001b[38;5;33mDepthwiseConv2D\u001b[0m) │ │ │ │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ B1.D1.DW.B1.BN │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m, \u001b[38;5;34m256\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m4\u001b[0m │ B1.D1.DW.B1.CN[\u001b[38;5;34m0\u001b[0m… │\n", "│ (\u001b[38;5;33mBatchNormalizatio…\u001b[0m │ │ │ │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ B1.D1.DW.ACT │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m, \u001b[38;5;34m256\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ B1.D1.DW.B1.BN[\u001b[38;5;34m0\u001b[0m… │\n", "│ (\u001b[38;5;33mActivation\u001b[0m) │ │ │ │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ B1.D1.PW.B1.CN │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m, \u001b[38;5;34m256\u001b[0m, │ \u001b[38;5;34m16\u001b[0m │ B1.D1.DW.ACT[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n", "│ (\u001b[38;5;33mConv2D\u001b[0m) │ \u001b[38;5;34m16\u001b[0m) │ │ │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ B1.D1.PW.B1.BN │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m, \u001b[38;5;34m256\u001b[0m, │ \u001b[38;5;34m64\u001b[0m │ B1.D1.PW.B1.CN[\u001b[38;5;34m0\u001b[0m… │\n", "│ (\u001b[38;5;33mBatchNormalizatio…\u001b[0m │ \u001b[38;5;34m16\u001b[0m) │ │ │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ B1.D1.PW.ACT │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m, \u001b[38;5;34m256\u001b[0m, │ \u001b[38;5;34m0\u001b[0m │ B1.D1.PW.B1.BN[\u001b[38;5;34m0\u001b[0m… │\n", "│ (\u001b[38;5;33mActivation\u001b[0m) │ \u001b[38;5;34m16\u001b[0m) │ │ │\n", "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n", "│ B2.D1.DW.B1.CN │ 
(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m, \u001b[38;5;34m256\u001b[0m, │ \u001b[38;5;34m112\u001b[0m │ B1.D1.PW.ACT[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m…\u001b[0m │\n", "│ (\u001b[38;5;33mDepthwiseConv2D\u001b[0m) │ \u001b[38;5;34m16\u001b[0m) │ │ │\n", "└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Total params: 10,223 (39.93 KB)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m10,223\u001b[0m (39.93 KB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Trainable params: 9,675 (37.79 KB)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m9,675\u001b[0m (37.79 KB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Non-trainable params: 548 (2.14 KB)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m548\u001b[0m (2.14 KB)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model = nse.models.tcn.tcn_from_object(\n", " x=keras.Input(shape=(params.frame_size, 1), name='inputs'),\n", " params=architecture[\"params\"],\n", " num_classes=1\n", ")\n", "model.summary(layer_range=('inputs', model.layers[10].name))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train the model" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
INFO     Creating synthetic dataset cache with 5000 patients                                   ecg_synthetic.py:159\n",
       "
\n" ], "text/plain": [ "\u001b[34mINFO \u001b[0m Creating synthetic dataset cache with \u001b[1;36m5000\u001b[0m patients \u001b]8;id=172088;file:///workspaces/heartkit/heartkit/datasets/ecg_synthetic.py\u001b\\\u001b[2mecg_synthetic.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=461477;file:///workspaces/heartkit/heartkit/datasets/ecg_synthetic.py#159\u001b\\\u001b[2m159\u001b[0m\u001b]8;;\u001b\\\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Building ecg-synthetic cache: 100%|██████████| 5000/5000 [00:57<00:00, 86.91it/s] \n" ] }, { "data": { "text/html": [ "
INFO     Validation steps per epoch: 39                                                              datasets.py:85\n",
       "
\n" ], "text/plain": [ "\u001b[34mINFO \u001b[0m Validation steps per epoch: \u001b[1;36m39\u001b[0m \u001b]8;id=99779;file:///workspaces/heartkit/heartkit/tasks/denoise/datasets.py\u001b\\\u001b[2mdatasets.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=277033;file:///workspaces/heartkit/heartkit/tasks/denoise/datasets.py#85\u001b\\\u001b[2m85\u001b[0m\u001b]8;;\u001b\\\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:   0%|           0/100 ETA: ?s,  ?epochs/sWARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
      "I0000 00:00:1723838225.604155  751478 service.cc:146] XLA service 0x7a52b8001f20 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:\n",
      "I0000 00:00:1723838225.604174  751478 service.cc:154]   StreamExecutor device (0): NVIDIA GeForce RTX 4090, Compute Capability 8.9\n",
      "I0000 00:00:1723838232.858832  751478 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.\n",
      "Training: 100%|██████████ 100/100 ETA: 00:00s,   1.59s/epochs"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 975us/step - cos: 0.7118 - loss: 0.0511 - mae: 0.1445 - mse: 0.0452 - snr: 11.9220\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "
INFO     [VAL SET]COS=0.7079, LOSS=0.0528, MAE=0.1466, MSE=0.0469, SNR=11.9038                         train.py:149\n",
       "
\n" ], "text/plain": [ "\u001b[34mINFO \u001b[0m \u001b[1m[\u001b[0mVAL SET\u001b[1m]\u001b[0m\u001b[33mCOS\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1;36m.7079\u001b[0m, \u001b[33mLOSS\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1;36m.0528\u001b[0m, \u001b[33mMAE\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1;36m.1466\u001b[0m, \u001b[33mMSE\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1;36m.0469\u001b[0m, \u001b[33mSNR\u001b[0m=\u001b[1;36m11\u001b[0m\u001b[1;36m.9038\u001b[0m \u001b]8;id=347748;file:///workspaces/heartkit/heartkit/tasks/denoise/train.py\u001b\\\u001b[2mtrain.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=260161;file:///workspaces/heartkit/heartkit/tasks/denoise/train.py#149\u001b\\\u001b[2m149\u001b[0m\u001b]8;;\u001b\\\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "task.train(params)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model evaluation\n", "\n", "Now that we have trained the model, we will evaluate the model on the test dataset. Similar to training, we will provide the high-level configuration to the task process." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
INFO     Creating synthetic dataset cache with 5000 patients                                   ecg_synthetic.py:159\n",
       "
\n" ], "text/plain": [ "\u001b[34mINFO \u001b[0m Creating synthetic dataset cache with \u001b[1;36m5000\u001b[0m patients \u001b]8;id=288389;file:///workspaces/heartkit/heartkit/datasets/ecg_synthetic.py\u001b\\\u001b[2mecg_synthetic.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=256787;file:///workspaces/heartkit/heartkit/datasets/ecg_synthetic.py#159\u001b\\\u001b[2m159\u001b[0m\u001b]8;;\u001b\\\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Building ecg-synthetic cache: 100%|██████████| 5000/5000 [00:57<00:00, 87.16it/s] \n" ] }, { "data": { "text/html": [ "
\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m39/39\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 25ms/step - cos: 0.7238 - loss: 0.0443 - mae: 0.1328 - mse: 0.0384 - snr: 12.3671\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "
INFO     [TEST SET] COS=0.7245, LOSS=0.0437, MAE=0.1316, MSE=0.0377, SNR=12.3787                     evaluate.py:37\n",
       "
\n" ], "text/plain": [ "\u001b[34mINFO \u001b[0m \u001b[1m[\u001b[0mTEST SET\u001b[1m]\u001b[0m \u001b[33mCOS\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1;36m.7245\u001b[0m, \u001b[33mLOSS\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1;36m.0437\u001b[0m, \u001b[33mMAE\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1;36m.1316\u001b[0m, \u001b[33mMSE\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1;36m.0377\u001b[0m, \u001b[33mSNR\u001b[0m=\u001b[1;36m12\u001b[0m\u001b[1;36m.3787\u001b[0m \u001b]8;id=893749;file:///workspaces/heartkit/heartkit/tasks/denoise/evaluate.py\u001b\\\u001b[2mevaluate.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=218337;file:///workspaces/heartkit/heartkit/tasks/denoise/evaluate.py#37\u001b\\\u001b[2m37\u001b[0m\u001b]8;;\u001b\\\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "task.evaluate(params)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export model to TF Lite / TFLM\n", "\n", "Once the model is trained and evaluated, we export it to a format that can be used for inference on edge devices. Currently, we export the model to the TensorFlow Lite flatbuffer format. The export step also generates a C header file (`tflm_file`) that can be used with TensorFlow Lite for Microcontrollers (TFLM).\n", "\n", "For this model, we export a 32-bit floating-point (FP32) model.\n", " \n", "__NOTE:__ We use `CONCRETE` conversion mode to lower the model to concrete functions before converting it, because the TF (MLIR) converter fails to properly lower the dilated convolutional layers." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "quantization = hk.QuantizationParams(\n", " enabled=True,\n", " format=\"FP32\",\n", " io_type=\"float32\",\n", " conversion=\"CONCRETE\",\n", ")\n", "params.quantization = quantization" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
INFO     Creating synthetic dataset cache with 5000 patients                                   ecg_synthetic.py:159\n",
       "
\n" ], "text/plain": [ "\u001b[34mINFO \u001b[0m Creating synthetic dataset cache with \u001b[1;36m5000\u001b[0m patients \u001b]8;id=313048;file:///workspaces/heartkit/heartkit/datasets/ecg_synthetic.py\u001b\\\u001b[2mecg_synthetic.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=514688;file:///workspaces/heartkit/heartkit/datasets/ecg_synthetic.py#159\u001b\\\u001b[2m159\u001b[0m\u001b]8;;\u001b\\\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "
[08/16/24 20:02:23] WARNING  WARNING:absl:Please consider providing the trackable_obj argument in the  lite.py:2166\n",
       "                             from_concrete_functions. Providing without the trackable_obj argument is              \n",
       "                             deprecated and it will use the deprecated conversion path.                            \n",
       "
\n" ], "text/plain": [ "\u001b[2;36m[08/16/24 20:02:23]\u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m WARNING:absl:Please consider providing the trackable_obj argument in the \u001b]8;id=520246;file:///workspaces/heartkit/.venv/lib/python3.12/site-packages/tensorflow/lite/python/lite.py\u001b\\\u001b[2mlite.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=384487;file:///workspaces/heartkit/.venv/lib/python3.12/site-packages/tensorflow/lite/python/lite.py#2166\u001b\\\u001b[2m2166\u001b[0m\u001b]8;;\u001b\\\n", "\u001b[2;36m \u001b[0m from_concrete_functions. Providing without the trackable_obj argument is \u001b[2m \u001b[0m\n", "\u001b[2;36m \u001b[0m deprecated and it will use the deprecated conversion path. \u001b[2m \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
INFO     Validating model results                                                                      export.py:83\n",
       "
\n" ], "text/plain": [ "\u001b[34mINFO \u001b[0m Validating model results \u001b]8;id=941295;file:///workspaces/heartkit/heartkit/tasks/denoise/export.py\u001b\\\u001b[2mexport.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=727514;file:///workspaces/heartkit/heartkit/tasks/denoise/export.py#83\u001b\\\u001b[2m83\u001b[0m\u001b]8;;\u001b\\\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "I0000 00:00:1723838543.688860 751181 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1723838543.688944 751181 devices.cc:67] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 1\n", "I0000 00:00:1723838543.689113 751181 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1723838543.689169 751181 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1723838543.689214 751181 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. 
See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1723838543.689287 751181 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "I0000 00:00:1723838543.689333 751181 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "W0000 00:00:1723838543.815333 751181 tf_tfl_flatbuffer_helpers.cc:392] Ignored output_format.\n", "W0000 00:00:1723838543.815348 751181 tf_tfl_flatbuffer_helpers.cc:395] Ignored drop_control_dependency.\n", "INFO: Created TensorFlow Lite XNNPACK delegate for CPU.\n" ] }, { "data": { "text/html": [ "
INFO     [TF METRICS] LOSS=0.0396 MAE=0.1357 MSE=0.0396 RMSE=0.1991 COSINE=0.7178                      export.py:90\n",
       "
\n" ], "text/plain": [ "\u001b[34mINFO \u001b[0m \u001b[1m[\u001b[0mTF METRICS\u001b[1m]\u001b[0m \u001b[33mLOSS\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1;36m.0396\u001b[0m \u001b[33mMAE\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1;36m.1357\u001b[0m \u001b[33mMSE\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1;36m.0396\u001b[0m \u001b[33mRMSE\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1;36m.1991\u001b[0m \u001b[33mCOSINE\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1;36m.7178\u001b[0m \u001b]8;id=496666;file:///workspaces/heartkit/heartkit/tasks/denoise/export.py\u001b\\\u001b[2mexport.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=35165;file:///workspaces/heartkit/heartkit/tasks/denoise/export.py#90\u001b\\\u001b[2m90\u001b[0m\u001b]8;;\u001b\\\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
INFO     [TFL METRICS] LOSS=0.0396 MAE=0.1357 MSE=0.0396 RMSE=0.1991 COSINE=0.7177                     export.py:91\n",
       "
\n" ], "text/plain": [ "\u001b[34mINFO \u001b[0m \u001b[1m[\u001b[0mTFL METRICS\u001b[1m]\u001b[0m \u001b[33mLOSS\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1;36m.0396\u001b[0m \u001b[33mMAE\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1;36m.1357\u001b[0m \u001b[33mMSE\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1;36m.0396\u001b[0m \u001b[33mRMSE\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1;36m.1991\u001b[0m \u001b[33mCOSINE\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1;36m.7177\u001b[0m \u001b]8;id=70190;file:///workspaces/heartkit/heartkit/tasks/denoise/export.py\u001b\\\u001b[2mexport.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=68341;file:///workspaces/heartkit/heartkit/tasks/denoise/export.py#91\u001b\\\u001b[2m91\u001b[0m\u001b]8;;\u001b\\\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
INFO     Validation passed (0.0000)                                                                    export.py:99\n",
       "
\n" ], "text/plain": [ "\u001b[34mINFO \u001b[0m Validation passed \u001b[1m(\u001b[0m\u001b[1;36m0.0000\u001b[0m\u001b[1m)\u001b[0m \u001b]8;id=375015;file:///workspaces/heartkit/heartkit/tasks/denoise/export.py\u001b\\\u001b[2mexport.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=497680;file:///workspaces/heartkit/heartkit/tasks/denoise/export.py#99\u001b\\\u001b[2m99\u001b[0m\u001b]8;;\u001b\\\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# TF is very verbose, so we redirect Python's stdout and stderr to os.devnull.\n", "# Logs written directly by TF's native (C++) runtime bypass this redirection and may still appear.\n", "with open(os.devnull, 'w') as devnull:\n", "    with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):\n", "        task.export(params)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## ECG Denoising Demo\n", "\n", "Finally, we demonstrate how to use the trained denoiser to remove noise and artifacts from ECG signals. We draw a sample ECG signal, corrupt it with the noise augmentation pipeline, and then denoise it with the trained model. We then plot the reference, noisy, and denoised signals, annotating the noisy and denoised traces with their SNR relative to the reference, to compare the results."
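, "\n",
"\n",
"To quantify the improvement, the demo reports the signal-to-noise ratio (SNR) of the noisy and denoised signals relative to the reference. We have not confirmed the exact implementation of `nse.metrics.Snr`; assuming it follows the conventional definition, for a reference signal $x$ and an estimate $\\hat{x}$, both of length $N$:\n",
"\n",
"$$\n",
"\\mathrm{SNR}_{\\mathrm{dB}} = 10 \\log_{10} \\frac{\\sum_{n=1}^{N} x[n]^2}{\\sum_{n=1}^{N} \\left( x[n] - \\hat{x}[n] \\right)^2}\n",
"$$\n",
"\n",
"A higher SNR means the estimate is closer to the reference. For example, if the residual energy is one tenth of the reference energy, the SNR is $10 \\log_{10} 10 = 10$ dB. A successful denoiser should therefore report a larger SNR for the denoised signal than for the noisy one."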
] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "model = nse.models.load_model(params.model_file)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 9ms/step\n" ] } ], "source": [ "# Draw a reference ECG sample, then preprocess it and corrupt it with the noise augmentations\n", "ecg = next(ds_gen)\n", "aug_ecg = augmenter(preprocessor(keras.ops.convert_to_tensor(np.reshape(ecg, (1, -1, 1)))), training=True).numpy().squeeze()\n", "# Denoise the corrupted signal with the trained model\n", "clean_ecg = model.predict(np.reshape(aug_ecg, (1, -1, 1)))\n", "# SNR (dB) of the noisy signal relative to the reference\n", "snr = nse.metrics.Snr()\n", "snr.update_state(ecg.reshape(1, -1, 1), aug_ecg.reshape(1, -1, 1))\n", "aug_snr = snr.result().numpy()\n", "# Reset the metric, then compute the SNR (dB) of the denoised signal\n", "snr.reset_state()\n", "snr.update_state(ecg.reshape(1, -1, 1), clean_ecg.reshape(1, -1, 1))\n", "clean_snr = snr.result().numpy()" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, ax = plt.subplots(3, 1, figsize=(9, 5), sharex=True)\n", "ax[0].plot(ts, ecg.squeeze(), color=plot_theme.primary_color, lw=3)\n", "ax[1].plot(ts, aug_ecg.squeeze(), color=plot_theme.secondary_color, lw=3)\n", "ax[2].plot(ts, clean_ecg.squeeze(), color=plot_theme.tertiary_color, lw=3)\n", "\n", "ax[0].set_ylabel(\"Reference\")\n", "ax[1].set_ylabel(\"Noisy\")\n", "ax[2].set_ylabel(\"Denoised\")\n", "\n", "ax[1].text(0.98, 0.15, f\"{aug_snr:4.02f} dB SNR\", transform=ax[1].transAxes, ha=\"right\", va=\"top\", weight='bold')\n", "ax[2].text(0.98, 0.15, f\"{clean_snr:4.02f} dB SNR\", transform=ax[2].transAxes, ha=\"right\", va=\"top\", weight='bold')\n", "# Disable y-axis ticks for all plots\n", "for axes in ax:\n", "    axes.yaxis.set_ticks([])\n", "ax[-1].set_xlabel(\"Time (s)\")\n", "fig.suptitle(\"ECG Denoising Demo\")\n", "fig.tight_layout()\n", "fig.show()" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.5" } }, "nbformat": 4, "nbformat_minor": 2 }