Train Sleep Detection Model
Date created: 2025/10/02
Last Modified: 2025/10/02
Description: Train a simple wrist-based sleep detection model using accelerometer data.
Overview
In this guide, we will train a small temporal convolutional network (TCN) to classify wrist-worn accelerometer data into sleep and wake states.
Input
- Sensor: IMU
- Location: Wrist
- Sampling Rate: 0.2 Hz
- Frame Size: 60 seconds
Class Mapping
Classify the data into one of two categories: WAKE or SLEEP.
| Base Class | Target Class | Label |
|---|---|---|
| 0-WAKE | 0 | WAKE |
| 1-SLEEP | 1 | SLEEP |
Datasets
- CMIDSS: The Child Mind Institute - Detect Sleep States (CMIDSS) dataset comprises 300 subjects with over 500 multi-day recordings of wrist-worn accelerometer data, annotated with two event types: onset (the beginning of sleep) and wakeup (the end of sleep).
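CMIDSS marks sleep with onset and wakeup events rather than with per-sample labels. The sketch below shows one way such events could be turned into a per-step WAKE/SLEEP label array. The helper name, the use of integer step indices, and the assumption that each onset pairs with the following wakeup are illustrative only; this is not sleepKIT's actual preprocessing.

```python
import numpy as np

def events_to_labels(num_steps, onsets, wakeups):
    """Hypothetical helper: per-step labels, 1 (SLEEP) from each onset up to
    (but not including) its wakeup, 0 (WAKE) everywhere else.

    Assumes onsets and wakeups are step indices that alternate cleanly, so the
    i-th sorted onset pairs with the i-th sorted wakeup.
    """
    labels = np.zeros(num_steps, dtype=np.int32)
    for onset, wakeup in zip(sorted(onsets), sorted(wakeups)):
        labels[onset:wakeup] = 1  # half-open interval [onset, wakeup)
    return labels

labels = events_to_labels(12, onsets=[2, 8], wakeups=[5, 11])
print(labels.tolist())  # [0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0]
```

Real recordings also contain gaps and unpaired events, which a production pipeline would need to mask out rather than label.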
Setup
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import IPython
import contextlib
import tempfile
from pathlib import Path
import keras
import helia_edge as helia
import sleepkit as sk
# Be sure to set the dataset path to the correct location
datasets_dir = Path(os.getenv('SK_DATASET_PATH', '../../datasets'))
plot_theme = sk.utils.dark_theme
helia.utils.silence_tensorflow()
sk.utils.setup_plotting(plot_theme)
logger = helia.utils.setup_logger(__name__)
Configure datasets
We are going to train our model using the [CMIDSS Dataset](https://ambiqai.github.io/sleepkit/datasets/cmidss/). Within sleepKIT, this dataset uses the slug cmidss. The dataset will be downloaded if it is not already available.
datasets = [sk.NamedParams(
    name="cmidss",
    params=dict(
        path=datasets_dir / "cmidss",
    )
)]
Target classes
For this task, we will classify the data into two classes: WAKE and SLEEP. All sleep stages (stages 1 through 4 and REM) are merged into the SLEEP class.
class_map = {
    sk.SleepStage.wake: 0,
    sk.SleepStage.stage1: 1,
    sk.SleepStage.stage2: 1,
    sk.SleepStage.stage3: 1,
    sk.SleepStage.stage4: 1,
    sk.SleepStage.rem: 1,
}
class_names = ["WAKE", "SLEEP"]
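To see what the class map does to a label sequence, here is a small NumPy sketch that collapses base sleep-stage labels into the two target classes with a lookup table. The SleepStage enum below is a hypothetical stand-in for sk.SleepStage, and its integer values are assumptions; sleepKIT applies class_map internally and may do so differently.

```python
from enum import IntEnum

import numpy as np

class SleepStage(IntEnum):
    """Hypothetical stand-in for sk.SleepStage (integer values are assumed)."""
    wake = 0
    stage1 = 1
    stage2 = 2
    stage3 = 3
    stage4 = 4
    rem = 5

class_map = {
    SleepStage.wake: 0,
    SleepStage.stage1: 1,
    SleepStage.stage2: 1,
    SleepStage.stage3: 1,
    SleepStage.stage4: 1,
    SleepStage.rem: 1,
}
class_names = ["WAKE", "SLEEP"]

def remap_labels(stages, class_map):
    """Map base stage labels to target class indices; unmapped stages become -1."""
    lut = np.full(max(class_map) + 1, -1, dtype=np.int32)
    for src, dst in class_map.items():
        lut[int(src)] = dst
    return lut[np.asarray(stages)]

targets = remap_labels([0, 1, 2, 5, 0, 3], class_map)
print([class_names[t] for t in targets])
# ['WAKE', 'SLEEP', 'SLEEP', 'SLEEP', 'WAKE', 'SLEEP']
```

Filling the table with -1 rather than 0 keeps any stage that is missing from the map from silently becoming WAKE.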
Feature set
From the dataset, we will create a feature set using the FS-W-A-5 features. This feature set computes 5 features over 60-second windows of accelerometer data collected on the wrist. The CMIDSS dataset already provides accelerometer data averaged over 5 seconds (i.e., Fs = 0.2 Hz). Therefore, we use a frame size of 12 samples to capture 60 seconds of data (12 samples at 0.2 Hz) with a 50% overlap between consecutive windows.
feature = dict(
    name="FS-W-A-5",
    sampling_rate=0.2,
    frame_size=12,
    loader="hdf5",
    feat_key="features",
    label_key="detect_labels",
    mask_key="mask",
    feat_cols=None,
    save_path=datasets_dir / "store" / "fs-w-a-5-60",
    params={},
)
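The windowing arithmetic behind these settings can be checked with a short NumPy sketch: 60 s at 0.2 Hz gives 12 samples per window, and 50% overlap gives a hop of 6 samples (30 s). This assumes a plain sliding window that drops any partial tail; the FS-W-A-5 implementation may pad or align windows differently, and the 5 per-window features themselves are not computed here.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

SAMPLING_RATE_HZ = 0.2  # one averaged accelerometer sample every 5 s
WINDOW_S = 60           # feature window length in seconds
OVERLAP = 0.5           # 50% overlap between consecutive windows

frame_size = round(WINDOW_S * SAMPLING_RATE_HZ)  # 12 samples
hop = round(frame_size * (1 - OVERLAP))          # 6 samples, i.e. 30 s

def frame_signal(x, frame_size, hop):
    """Split a 1-D signal into overlapping frames, dropping any partial tail."""
    windows = sliding_window_view(np.asarray(x), frame_size)
    return windows[::hop]

# One hour at 0.2 Hz is 720 samples.
x = np.arange(720, dtype=np.float32)
frames = frame_signal(x, frame_size, hop)
print(frame_size, hop, frames.shape)  # 12 6 (119, 12)
```

The frame count follows (N - frame_size) / hop + 1 = (720 - 12) / 6 + 1 = 119.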
Define TCN model architecture
For this task, we leverage a customized TCN architecture that is smaller and handles 1D signals. The model consists of 4 TCN blocks, each with a depth of 1. Each block uses dilated depthwise-separable convolutions along with inverted expansion and squeeze-and-excitation layers. The blocks are followed by a 1D convolutional layer and a final dense layer for classification. Unlike vision tasks, we use larger kernel sizes and increasing dilation rates to capture temporal dependencies in the signal.
architecture = sk.NamedParams(
    name="tcn",
    params=dict(
        input_kernel=[1, 5],
        input_norm="batch",
        blocks=[
            dict(depth=1, branch=1, filters=16, kernel=(1, 5), dilation=[1, 1], dropout=0.10, ex_ratio=1, se_ratio=4, norm="batch"),
            dict(depth=1, branch=1, filters=32, kernel=(1, 5), dilation=[1, 2], dropout=0.10, ex_ratio=1, se_ratio=4, norm="batch"),
            dict(depth=1, branch=1, filters=48, kernel=(1, 5), dilation=[1, 4], dropout=0.10, ex_ratio=1, se_ratio=4, norm="batch"),
            dict(depth=1, branch=1, filters=64, kernel=(1, 5), dilation=[1, 8], dropout=0.10, ex_ratio=1, se_ratio=4, norm="batch")
        ],
        output_kernel=(1, 5),
        include_top=True,
        use_logits=True,
        model_name="tcn"
    )
)
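Why do increasing dilations help? A stack of stride-1 convolutions widens its receptive field by (k - 1) * d per layer. The sketch below applies that formula under stated assumptions. It counts one temporal convolution for the input kernel, one dilated depthwise convolution per block, and one for the output kernel. It also assumes the second entry of each dilation pair acts on the time axis, since the inputs are reshaped to (1, T, C). The global pooling inside the squeeze-and-excitation layers adds context that this formula does not capture.

```python
def receptive_field(kernels, dilations):
    """Receptive field (in input time steps) of stacked stride-1 convolutions.

    A layer with kernel size k and dilation d widens the field by (k - 1) * d.
    """
    rf = 1
    for k, d in zip(kernels, dilations):
        rf += (k - 1) * d
    return rf

# Assumed temporal layers: input conv, blocks 1-4 (dilations 1, 2, 4, 8), output conv.
kernels = [5, 5, 5, 5, 5, 5]
dilations = [1, 1, 2, 4, 8, 1]
print(receptive_field(kernels, dilations))  # 69
```

Under these assumptions, each output step sees 69 consecutive feature frames. Without dilation (all d = 1), the same six layers would see only 25.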
Task configuration
Here we provide the complete configuration for the task. This includes the dataset configuration, features, model architecture, and training parameters.
params = sk.TaskParams(
    name="sk-detect",
    job_dir=Path(tempfile.gettempdir()) / "sk-detect",
    verbose=1,
    datasets=datasets,
    feature=feature,
    sampling_rate=0.0083333,
    frame_size=240,
    num_classes=len(class_names),
    class_map=class_map,
    class_names=class_names,
    samples_per_subject=100,
    val_samples_per_subject=100,
    test_samples_per_subject=50,
    val_size=4000,
    test_size=2500,
    val_subjects=0.20,
    batch_size=128,
    buffer_size=10000,
    epochs=200,
    steps_per_epoch=25,
    val_steps_per_epoch=25,
    val_metric="loss",
    lr_rate=1e-3,
    lr_cycles=1,
    label_smoothing=0,
    test_metric="f1",
    test_metric_threshold=0.02,
    tflm_var_name="sk_detect_flatbuffer",
    tflm_file="sk_detect_flatbuffer.h",
    backend="pc",
    display_report=False,
    model_file="model.keras",
    use_logits=False,
    architecture=architecture
)
Load detect task¶
sleepKIT provides a TaskFactory that includes a number of ready-to-use tasks. Each task provides methods for training, evaluating, exporting, and demoing. We will grab the detect task and configure it for our use case.
task = sk.TaskFactory.get("detect")
Download the datasets¶
We will download the datasets using sleepKIT. If the datasets have already been downloaded, this step is skipped.
task.download(params=params)
Generate the features¶
Next, we will generate features from the given dataset using the FS-W-A-5 feature set.
Once the command finishes, the feature set will be saved in the feature["save_path"] directory as HDF5 files, one file per subject. Each HDF5 file includes the following entries:
- /features: Time x Feature tensor (fp32). Features are computed over windows of sensor data.
- /mask: Time x Mask tensor (bool). The mask indicates valid feature values.
- /detect_labels: Time x Label tensor (int). Labels are WAKE (0) or SLEEP (1).
task.feature(params=params)
Gen features for cmidss: 100%|██████████| 277/277 [02:01<00:00, 2.28it/s]
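One plausible way to use the mask entry is to keep only the time steps whose features are valid. The sketch below does this with toy NumPy arrays standing in for the /features, /mask and /detect_labels entries. The helper name is hypothetical, and the assumption that the mask is either (T,) or (T, M) is ours; sleepKIT's data loader may apply the mask differently.

```python
import numpy as np

def select_valid(features, mask, labels):
    """Hypothetical helper: keep time steps where every mask entry is True.

    Assumes features is (T, F), labels is (T,), and mask is (T,) or (T, M).
    """
    mask = np.asarray(mask, dtype=bool)
    valid = mask if mask.ndim == 1 else mask.all(axis=1)
    return features[valid], labels[valid]

# Toy arrays in place of one subject's HDF5 entries. With h5py they could be
# read as f["features"][:] inside `with h5py.File(path, "r") as f:`, where
# path would be one of the per-subject files in the save directory.
features = np.arange(10, dtype=np.float32).reshape(5, 2)
mask = np.array([True, True, False, True, False])
labels = np.array([0, 1, 1, 1, 0])

x, y = select_valid(features, mask, labels)
print(x.shape, y.tolist())  # (3, 2) [0, 1, 1]
```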
Visualize the model
Let's quickly instantiate the model and print a summary of its first layers.
model = helia.models.tcn.TcnModel.model_from_params(
    inputs=keras.Input(shape=(params.frame_size, 5), name="inputs"),
    params=architecture.params,
    num_classes=len(class_names)
)
model.summary(layer_range=('inputs', model.layers[10].name))
Model: "functional"

| Layer (type) | Output Shape | Param # | Connected to |
|---|---|---|---|
| inputs (InputLayer) | (None, 240, 5) | 0 | - |
| reshape (Reshape) | (None, 1, 240, 5) | 0 | inputs[0][0] |
| ENC_CN (DepthwiseConv2D) | (None, 1, 240, 5) | 25 | reshape[0][0] |
| ENC_BN (BatchNormalizatio…) | (None, 1, 240, 5) | 20 | ENC_CN[0][0] |
| B1_D1_DW_B1_CN (DepthwiseConv2D) | (None, 1, 240, 5) | 25 | ENC_BN[0][0] |
| B1_D1_DW_B1_BN (BatchNormalizatio…) | (None, 1, 240, 5) | 20 | B1_D1_DW_B1_CN[0… |
| B1_D1_DW_ACT (Activation) | (None, 1, 240, 5) | 0 | B1_D1_DW_B1_BN[0… |
| B1_D1_SE_pool (GlobalAveragePool…) | (None, 1, 1, 5) | 0 | B1_D1_DW_ACT[0][… |
| B1_D1_SE_sq (Conv2D) | (None, 1, 1, 1) | 6 | B1_D1_SE_pool[0]… |
| B1_D1_SE_sq.act (Activation) | (None, 1, 1, 1) | 0 | B1_D1_SE_sq[0][0] |
| B1_D1_SE_ex (Conv2D) | (None, 1, 1, 5) | 10 | B1_D1_SE_sq.act[… |
| B1_D1_SE_ex.act (Activation) | (None, 1, 1, 5) | 0 | B1_D1_SE_ex[0][0] |
| B1_D1_SE_ex.mul (Multiply) | (None, 1, 240, 5) | 0 | B1_D1_DW_ACT[0][…, B1_D1_SE_ex.act[… |
Total params: 9,364 (36.58 KB)
Trainable params: 8,832 (34.50 KB)
Non-trainable params: 532 (2.08 KB)
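The parameter counts in the summary follow from simple formulas, and the sketch below reproduces them for the 5-channel layers shown. It assumes the depthwise convolutions have no bias (consistent with the count of 25) and that batch normalization stores four values per channel. Two of those values (gamma, beta) are trainable and two (moving mean, moving variance) are not, which is where the non-trainable parameters come from.

```python
def depthwise_conv_params(kernel_size, channels, depth_multiplier=1, use_bias=False):
    """Weights of a DepthwiseConv2D: one kernel per input channel (per multiplier)."""
    params = kernel_size * channels * depth_multiplier
    if use_bias:
        params += channels * depth_multiplier
    return params

def batch_norm_params(channels):
    """(trainable, non-trainable): gamma/beta are trained; moving mean/variance are not."""
    return 2 * channels, 2 * channels

def conv1x1_params(in_channels, out_channels, use_bias=True):
    """Weights of a 1x1 Conv2D, such as the squeeze/excite convolutions."""
    return in_channels * out_channels + (out_channels if use_bias else 0)

channels = 5
kernel = 1 * 5  # (1, 5) kernel
print(depthwise_conv_params(kernel, channels))  # 25 -> ENC_CN, B1_D1_DW_B1_CN
print(sum(batch_norm_params(channels)))         # 20 -> ENC_BN, B1_D1_DW_B1_BN
print(conv1x1_params(channels, 1))              # 6  -> B1_D1_SE_sq
print(conv1x1_params(1, channels))              # 10 -> B1_D1_SE_ex
```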
Train the model
At this point, we can train the model on the generated feature set for the sleep detect task. The model will be trained for 200 epochs with a batch size of 128 and a learning rate of 1e-3. Each input has a frame_size of 240 feature frames; because the 60-second feature windows overlap by 50%, consecutive frames are 30 seconds apart, so each input spans 120 minutes.
Using the task configuration, we will train the model on the dataset.
task.train(params)
Epoch 1/200
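The training budget and input span can be worked out directly from the configuration. The sketch assumes consecutive feature frames are 30 s apart (60-second windows with 50% overlap) and that every step draws one full batch.

```python
FRAME_SIZE = 240      # feature frames per model input (params.frame_size)
HOP_S = 30            # assumed spacing between feature frames, in seconds
BATCH_SIZE = 128      # params.batch_size
STEPS_PER_EPOCH = 25  # params.steps_per_epoch
EPOCHS = 200          # params.epochs

window_minutes = FRAME_SIZE * HOP_S / 60
samples_per_epoch = BATCH_SIZE * STEPS_PER_EPOCH
total_samples = samples_per_epoch * EPOCHS
print(window_minutes, samples_per_epoch, total_samples)  # 120.0 3200 640000
```

Because steps_per_epoch is fixed, an "epoch" here is 3,200 sampled windows rather than a full pass over the dataset.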
Model evaluation
Now that we have trained the model, we will evaluate the model on the test dataset. Similar to training, we will provide the high-level configuration to the task process.
task.evaluate(params)
Subject: 100%|██████████| 56/56 [00:17<00:00, 3.21it/s]
INFO Testing Results evaluate.py:130
19/19 ━━━━━━━━━━━━━━━━━━━━ 2s 25ms/step - acc: 0.9436 - f1: 0.9448 - iou: 0.8600 - loss: 0.0101
INFO [TEST SET] acc=94.59%, f1=94.65%, iou=86.69%, loss=0.97% evaluate.py:132
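To make the reported metrics concrete, the sketch below computes accuracy, F1 and IoU (Jaccard index) for one class from hard predictions. It is a generic illustration, not sleepKIT's metric code. The reported values may be averaged across classes in a way this per-class sketch does not reproduce, and classes with no positives would produce a division by zero here.

```python
import numpy as np

def binary_scores(y_true, y_pred, positive=1):
    """Accuracy, F1 and IoU for the given positive class."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    tp = np.sum((y_pred == positive) & (y_true == positive))
    fp = np.sum((y_pred == positive) & (y_true != positive))
    fn = np.sum((y_pred != positive) & (y_true == positive))
    acc = np.mean(y_true == y_pred)
    f1 = 2 * tp / (2 * tp + fp + fn)
    iou = tp / (tp + fp + fn)
    return float(acc), float(f1), float(iou)

y_true = [0, 0, 1, 1, 1, 1, 0, 1]
y_pred = [0, 1, 1, 1, 0, 1, 0, 1]
acc, f1, iou = binary_scores(y_true, y_pred)
print(acc, f1, round(iou, 3))  # 0.75 0.8 0.667
```

For a single class, IoU and F1 are linked by IoU = F1 / (2 - F1), so IoU is never larger than F1.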
Confusion matrix
Let's visualize the confusion matrix to understand the model's performance on each class.
IPython.display.Image(filename=params.job_dir / "confusion_matrix_test.png", width=500)
Export model to TF Lite / TFLM
Once we have trained and evaluated the model, we need to export the model into a format that can be used for inference on the edge. Currently, we export the model to TensorFlow Lite flatbuffer format. This will also generate a C header file that can be used with TensorFlow Lite for Microcontrollers (TFLM).
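sleepKIT writes the C header itself during export (using tflm_var_name and tflm_file from the task configuration). The hypothetical helper below only sketches what such a header typically contains: the flatbuffer bytes as a C array plus a length constant, similar to `xxd -i` output. It is not the library's generator, and the bytes used here are dummies rather than a real model.

```python
def to_c_header(data: bytes, var_name: str, bytes_per_line: int = 12) -> str:
    """Hypothetical helper: render bytes as a C array with a length constant."""
    guard = f"{var_name.upper()}_H"
    rows = []
    for i in range(0, len(data), bytes_per_line):
        chunk = data[i:i + bytes_per_line]
        rows.append("    " + ", ".join(f"0x{b:02x}" for b in chunk) + ",")
    return "\n".join([
        f"#ifndef {guard}",
        f"#define {guard}",
        # TFLM usually expects model data to be 16-byte aligned; add an
        # alignment attribute suited to your toolchain in real use.
        f"const unsigned char {var_name}[] = {{",
        *rows,
        "};",
        f"const unsigned int {var_name}_len = {len(data)};",
        f"#endif  // {guard}",
        "",
    ])

# Dummy bytes; a real model would come from reading the exported .tflite file.
header = to_c_header(b"TFL3\x00\x01", "sk_detect_flatbuffer")
print(header)
```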
Apply post-training quantization (PTQ)
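For running on bare metal, post-training quantization converts the model to an 8-bit integer model: weights and activations are quantized to 8 bits and biases to 32 bits, which reduces model size and improves inference speed. Note that the configuration below sets format="FP32" and io_type="float32", which, as the names indicate, keeps the exported model and its inputs/outputs in 32-bit floating point; an integer format must be selected in sk.QuantizationParams to obtain the 8-bit model described here.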
quantization = sk.QuantizationParams(
    enabled=True,
    format="FP32",
    io_type="float32",
    conversion="CONCRETE",
)
params.quantization = quantization
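The 8-bit scheme described above can be illustrated with a small NumPy sketch of affine (scale and zero-point) quantization. This is a simplified, generic illustration using one per-tensor scale for an activation-like range, not the converter that sleepKIT or TensorFlow Lite actually runs. TFLite, for example, quantizes weights symmetrically per channel. The bias follows the TFLite convention of int32 values with scale input_scale * weight_scale and a zero point of 0.

```python
import numpy as np

def quantize_params(x_min, x_max, qmin=-128, qmax=127):
    """Scale and zero point for asymmetric int8 quantization of [x_min, x_max]."""
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)  # range must include 0
    scale = (x_max - x_min) / (qmax - qmin)
    zero_point = int(round(qmin - x_min / scale))
    return scale, zero_point

def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    q = np.round(np.asarray(x) / scale) + zero_point
    return np.clip(q, qmin, qmax).astype(np.int8)

def dequantize(q, scale, zero_point):
    return (q.astype(np.float32) - zero_point) * scale

def quantize_bias(bias, input_scale, weight_scale):
    """int32 bias with scale input_scale * weight_scale and zero point 0."""
    return np.round(np.asarray(bias) / (input_scale * weight_scale)).astype(np.int32)

x = np.linspace(-1.0, 3.0, 9, dtype=np.float32)
scale, zp = quantize_params(float(x.min()), float(x.max()))
q = quantize(x, scale, zp)
x_hat = dequantize(q, scale, zp)
print(scale, zp, q.tolist())
print("max round-trip error:", float(np.abs(x - x_hat).max()))
```

For values inside the calibrated range, the round-trip error stays within about half a quantization step (scale / 2). Values outside the range are clipped, which is why PTQ calibrates ranges on representative data.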
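# TF dumps a lot of info to stdout/stderr, so we redirect Python-level output to /dev/null.
# Note: redirect_stdout/redirect_stderr only replace sys.stdout/sys.stderr; messages written
# directly to the file descriptors by TensorFlow's C++ runtime can still appear.
with open(os.devnull, 'w') as devnull:
    with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
        task.export(params)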
Run inference demo¶
We will run a demo on the PC to verify that the model works as expected. The demo loads the model, runs inference on a randomly selected subject, and reports the model's predictions along with the corresponding class names.
task.demo(params=params)
Inference: 100%|██████████| 296/296 [00:06<00:00, 42.97it/s]
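The demo's own post-processing is not shown here. The sketch below is a hypothetical illustration of turning per-frame model outputs into class names and a rough total sleep time. It assumes raw logits (skip the softmax if the model already outputs probabilities) and 30 s between consecutive frames.

```python
import numpy as np

class_names = ["WAKE", "SLEEP"]
FRAME_HOP_S = 30  # assumed spacing between feature frames, in seconds

def softmax(logits, axis=-1):
    z = logits - np.max(logits, axis=axis, keepdims=True)  # subtract max for stability
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)

def decode_predictions(logits, class_names):
    """Convert (T, num_classes) logits into per-frame labels and confidences."""
    probs = softmax(np.asarray(logits, dtype=np.float64))
    idx = probs.argmax(axis=-1)
    return [class_names[i] for i in idx], probs.max(axis=-1)

logits = np.array([[2.0, -1.0], [0.1, 0.3], [-2.0, 3.0]])
labels, conf = decode_predictions(logits, class_names)
sleep_minutes = labels.count("SLEEP") * FRAME_HOP_S / 60
print(labels, sleep_minutes)  # ['WAKE', 'SLEEP', 'SLEEP'] 1.0
```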