{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "
INFO Job directory: /tmp/hk-foundation 1079341004.py:6\n", "\n" ], "text/plain": [ "\u001b[34mINFO \u001b[0m Job directory: \u001b[35m/tmp/\u001b[0m\u001b[95mhk-foundation\u001b[0m \u001b]8;id=625876;file:///tmp/ipykernel_712291/1079341004.py\u001b\\\u001b[2m1079341004.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=4474;file:///tmp/ipykernel_712291/1079341004.py#6\u001b\\\u001b[2m6\u001b[0m\u001b]8;;\u001b\\\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "nse.utils.silence_tensorflow()\n", "hk.utils.setup_plotting(plot_theme)\n", "logger = nse.utils.setup_logger(__name__, level=verbose)\n", "\n", "os.makedirs(job_dir, exist_ok=True)\n", "logger.info(f\"Job directory: {job_dir}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Configure datasets\n", "\n", "We are going to train our model using two large datasets: the PTB-XL dataset and the large-scale arrhythmia dataset (LSAD)." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "datasets = [\n", " hk.NamedParams(\n", " name=\"lsad\",\n", " params=dict(\n", " path=datasets_dir / \"lsad\"\n", " )\n", " ),\n", " hk.NamedParams(\n", " name=\"ptbxl\",\n", " params=dict(\n", " path=datasets_dir / \"ptbxl\"\n", " )\n", " ),\n", "]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Download datasets\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "for dataset in datasets:\n", " ds = hk.DatasetFactory.get(dataset.name)(\n", " **dataset.params\n", " )\n", " ds.download(force=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create data pipeline\n", "\n", "Next, we will create a `tf.data` pipeline by performing the following steps on each dataset:\n", "* Loading the dataset class handler\n", "* Creating a task-specific data loader for the dataset\n", "* Splitting the dataset's patients into training and validation sets\n", "* Creating `tf.data.Dataset` objects for 
training and validation\n", "\n", "After creating all the `tf.data.Dataset` objects, we will merge them into a single training dataset and a single validation dataset, sampling from each source dataset with equal probability.\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Load datasets\n", "dsets = [hk.DatasetFactory.get(ds.name)(**ds.params) for ds in datasets]" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "I0000 00:00:1723834403.812869 712291 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. 
See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n" ] } ], "source": [ "dset_weights = np.array([0.5, 0.5])\n", "\n", "train_datasets = []\n", "val_datasets = []\n", "for ds in dsets:\n", "\n", " # Create dataloader specific to dataset\n", " dataloader = hk.tasks.foundation.FoundationTaskFactory.get(ds.name)(\n", " ds=ds,\n", " frame_size=frame_size,\n", " sampling_rate=sampling_rate,\n", " )\n", "\n", " # Split patients into train and validation sets\n", " train_patients, val_patients = dataloader.split_train_val_patients()\n", "\n", " # Create train dataset\n", " train_ds = dataloader.create_dataloader(\n", " patient_ids=train_patients,\n", " samples_per_patient=samples_per_patient,\n", " shuffle=True\n", " )\n", "\n", " # Create validation dataset\n", " val_ds = dataloader.create_dataloader(\n", " patient_ids=val_patients,\n", " samples_per_patient=samples_per_patient,\n", " shuffle=False\n", " )\n", " train_datasets.append(train_ds)\n", " val_datasets.append(val_ds)\n", "# END FOR\n", "\n", "# Combine datasets\n", "train_ds = tf.data.Dataset.sample_from_datasets(train_datasets, weights=dset_weights)\n", "val_ds = tf.data.Dataset.sample_from_datasets(val_datasets, weights=dset_weights)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Visualize the data\n", "\n", "Let's visualize a sample ECG signal from the combined training dataset. Note that no augmentations have been applied at this stage; they will be applied later to generate noisy samples for training." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
INFO Model: \"EfficientNetV2\" summary_utils.py:389\n", " ┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓ \n", " ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ \n", " ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩ \n", " │ input (InputLayer) │ (None, 800, 1) │ 0 │ - │ \n", " ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ \n", " │ reshape (Reshape) │ (None, 1, 800, 1) │ 0 │ input[0][0] │ \n", " ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ \n", " │ stem.conv (Conv2D) │ (None, 1, 400, │ 216 │ reshape[0][0] │ \n", " │ │ 24) │ │ │ \n", " ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ \n", " │ stem.bn │ (None, 1, 400, │ 96 │ stem.conv[0][0] │ \n", " │ (BatchNormalizatio… │ 24) │ │ │ \n", " ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ \n", " │ stem.act │ (None, 1, 400, │ 0 │ stem.bn[0][0] │ \n", " │ (Activation) │ 24) │ │ │ \n", " ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ \n", " │ stage1.mbconv1.dp │ (None, 1, 400, │ 216 │ stem.act[0][0] │ \n", " │ (DepthwiseConv2D) │ 24) │ │ │ \n", " ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ \n", " │ stage1.mbconv1.dp.… │ (None, 1, 400, │ 96 │ stage1.mbconv1.d… │ \n", " │ (BatchNormalizatio… │ 24) │ │ │ \n", " ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ \n", " │ stage1.mbconv1.dp.… │ (None, 1, 400, │ 0 │ stage1.mbconv1.d… │ \n", " │ (Activation) │ 24) │ │ │ \n", " ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ \n", " │ max_pooling2d │ (None, 1, 200, │ 0 │ stage1.mbconv1.d… │ \n", " │ (MaxPooling2D) │ 24) │ │ │ \n", " ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ \n", " │ stage1.mbconv1.se.… │ (None, 1, 1, 24) │ 0 │ max_pooling2d[0]… │ \n", " │ (GlobalAveragePool… │ │ │ │ \n", " 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ \n", " │ stage1.mbconv1.se.… │ (None, 1, 1, 6) │ 150 │ stage1.mbconv1.s… │ \n", " │ (Conv2D) │ │ │ │ \n", " └─────────────────────┴───────────────────┴────────────┴───────────────────┘ \n", " Total params: 57,066 (222.91 KB) \n", " Trainable params: 55,050 (215.04 KB) \n", " Non-trainable params: 2,016 (7.88 KB) \n", " \n", "\n" ], "text/plain": [ "\u001b[34mINFO \u001b[0m Model: \u001b[32m\"EfficientNetV2\"\u001b[0m \u001b]8;id=445330;file:///workspaces/heartkit/.venv/lib/python3.12/site-packages/keras/src/utils/summary_utils.py\u001b\\\u001b[2msummary_utils.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=349863;file:///workspaces/heartkit/.venv/lib/python3.12/site-packages/keras/src/utils/summary_utils.py#389\u001b\\\u001b[2m389\u001b[0m\u001b]8;;\u001b\\\n", " ┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓ \u001b[2m \u001b[0m\n", " ┃ Layer \u001b[1m(\u001b[0mtype\u001b[1m)\u001b[0m ┃ Output Shape ┃ Param # ┃ Connected to ┃ \u001b[2m \u001b[0m\n", " ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩ \u001b[2m \u001b[0m\n", " │ input \u001b[1m(\u001b[0mInputLayer\u001b[1m)\u001b[0m │ \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m, \u001b[1;36m800\u001b[0m, \u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m │ \u001b[1;36m0\u001b[0m │ - │ \u001b[2m \u001b[0m\n", " ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ \u001b[2m \u001b[0m\n", " │ reshape \u001b[1m(\u001b[0mReshape\u001b[1m)\u001b[0m │ \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m, \u001b[1;36m1\u001b[0m, \u001b[1;36m800\u001b[0m, \u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m │ \u001b[1;36m0\u001b[0m │ input\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m]\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m]\u001b[0m │ \u001b[2m \u001b[0m\n", " ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ \u001b[2m \u001b[0m\n", 
" │ stem.conv \u001b[1m(\u001b[0mConv2D\u001b[1m)\u001b[0m │ \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m, \u001b[1;36m1\u001b[0m, \u001b[1;36m400\u001b[0m, │ \u001b[1;36m216\u001b[0m │ reshape\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m]\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m]\u001b[0m │ \u001b[2m \u001b[0m\n", " │ │ \u001b[1;36m24\u001b[0m\u001b[1m)\u001b[0m │ │ │ \u001b[2m \u001b[0m\n", " ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ \u001b[2m \u001b[0m\n", " │ stem.bn │ \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m, \u001b[1;36m1\u001b[0m, \u001b[1;36m400\u001b[0m, │ \u001b[1;36m96\u001b[0m │ stem.conv\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m]\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m]\u001b[0m │ \u001b[2m \u001b[0m\n", " │ \u001b[1m(\u001b[0mBatchNormalizatio… │ \u001b[1;36m24\u001b[0m\u001b[1m)\u001b[0m │ │ │ \u001b[2m \u001b[0m\n", " ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ \u001b[2m \u001b[0m\n", " │ stem.act │ \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m, \u001b[1;36m1\u001b[0m, \u001b[1;36m400\u001b[0m, │ \u001b[1;36m0\u001b[0m │ stem.bn\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m]\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m]\u001b[0m │ \u001b[2m \u001b[0m\n", " │ \u001b[1m(\u001b[0mActivation\u001b[1m)\u001b[0m │ \u001b[1;36m24\u001b[0m\u001b[1m)\u001b[0m │ │ │ \u001b[2m \u001b[0m\n", " ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ \u001b[2m \u001b[0m\n", " │ stage1.mbconv1.dp │ \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m, \u001b[1;36m1\u001b[0m, \u001b[1;36m400\u001b[0m, │ \u001b[1;36m216\u001b[0m │ stem.act\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m]\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m]\u001b[0m │ \u001b[2m \u001b[0m\n", " │ \u001b[1m(\u001b[0mDepthwiseConv2D\u001b[1m)\u001b[0m │ \u001b[1;36m24\u001b[0m\u001b[1m)\u001b[0m │ │ │ \u001b[2m \u001b[0m\n", " 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ \u001b[2m \u001b[0m\n", " │ stage1.mbconv1.dp.… │ \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m, \u001b[1;36m1\u001b[0m, \u001b[1;36m400\u001b[0m, │ \u001b[1;36m96\u001b[0m │ stage1.mbconv1.d… │ \u001b[2m \u001b[0m\n", " │ \u001b[1m(\u001b[0mBatchNormalizatio… │ \u001b[1;36m24\u001b[0m\u001b[1m)\u001b[0m │ │ │ \u001b[2m \u001b[0m\n", " ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ \u001b[2m \u001b[0m\n", " │ stage1.mbconv1.dp.… │ \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m, \u001b[1;36m1\u001b[0m, \u001b[1;36m400\u001b[0m, │ \u001b[1;36m0\u001b[0m │ stage1.mbconv1.d… │ \u001b[2m \u001b[0m\n", " │ \u001b[1m(\u001b[0mActivation\u001b[1m)\u001b[0m │ \u001b[1;36m24\u001b[0m\u001b[1m)\u001b[0m │ │ │ \u001b[2m \u001b[0m\n", " ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ \u001b[2m \u001b[0m\n", " │ max_pooling2d │ \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m, \u001b[1;36m1\u001b[0m, \u001b[1;36m200\u001b[0m, │ \u001b[1;36m0\u001b[0m │ stage1.mbconv1.d… │ \u001b[2m \u001b[0m\n", " │ \u001b[1m(\u001b[0mMaxPooling2D\u001b[1m)\u001b[0m │ \u001b[1;36m24\u001b[0m\u001b[1m)\u001b[0m │ │ │ \u001b[2m \u001b[0m\n", " ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ \u001b[2m \u001b[0m\n", " │ stage1.mbconv1.se.… │ \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m, \u001b[1;36m1\u001b[0m, \u001b[1;36m1\u001b[0m, \u001b[1;36m24\u001b[0m\u001b[1m)\u001b[0m │ \u001b[1;36m0\u001b[0m │ max_pooling2d\u001b[1m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m]\u001b[0m… │ \u001b[2m \u001b[0m\n", " │ \u001b[1m(\u001b[0mGlobalAveragePool… │ │ │ │ \u001b[2m \u001b[0m\n", " ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ \u001b[2m \u001b[0m\n", " │ stage1.mbconv1.se.… │ \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m, \u001b[1;36m1\u001b[0m, \u001b[1;36m1\u001b[0m, \u001b[1;36m6\u001b[0m\u001b[1m)\u001b[0m │ 
\u001b[1;36m150\u001b[0m │ stage1.mbconv1.s… │ \u001b[2m \u001b[0m\n", " │ \u001b[1m(\u001b[0mConv2D\u001b[1m)\u001b[0m │ │ │ │ \u001b[2m \u001b[0m\n", " └─────────────────────┴───────────────────┴────────────┴───────────────────┘ \u001b[2m \u001b[0m\n", " Total params: \u001b[1;36m57\u001b[0m,\u001b[1;36m066\u001b[0m \u001b[1m(\u001b[0m\u001b[1;36m222.91\u001b[0m KB\u001b[1m)\u001b[0m \u001b[2m \u001b[0m\n", " Trainable params: \u001b[1;36m55\u001b[0m,\u001b[1;36m050\u001b[0m \u001b[1m(\u001b[0m\u001b[1;36m215.04\u001b[0m KB\u001b[1m)\u001b[0m \u001b[2m \u001b[0m\n", " Non-trainable params: \u001b[1;36m2\u001b[0m,\u001b[1;36m016\u001b[0m \u001b[1m(\u001b[0m\u001b[1;36m7.88\u001b[0m KB\u001b[1m)\u001b[0m \u001b[2m \u001b[0m\n", " \u001b[2m \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
INFO Computation: 4.17 MFLOPs 909537700.py:3\n", "\n" ], "text/plain": [ "\u001b[34mINFO \u001b[0m Computation: \u001b[1;36m4.17\u001b[0m MFLOPs \u001b]8;id=614122;file:///tmp/ipykernel_712291/909537700.py\u001b\\\u001b[2m909537700.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=675398;file:///tmp/ipykernel_712291/909537700.py#3\u001b\\\u001b[2m3\u001b[0m\u001b]8;;\u001b\\\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "encoder.summary(print_fn=logger.info, layer_range=('input', encoder.layers[10].name))\n", "flops = nse.metrics.flops.get_flops(encoder, batch_size=1, fpath=os.devnull)\n", "logger.info(f\"Computation: {flops/1e6:0.2f} MFLOPs\")\n", "encoder_output = encoder(inputs)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
INFO Model: \"projector\" summary_utils.py:389\n", " ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ \n", " ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ \n", " ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ \n", " │ keras_tensor_109CLONE │ (None, 128) │ 0 │ \n", " │ (InputLayer) │ │ │ \n", " ├─────────────────────────────────┼────────────────────────┼───────────────┤ \n", " │ dense (Dense) │ (None, 128) │ 16,512 │ \n", " ├─────────────────────────────────┼────────────────────────┼───────────────┤ \n", " │ dense_1 (Dense) │ (None, 128) │ 16,512 │ \n", " └─────────────────────────────────┴────────────────────────┴───────────────┘ \n", " Total params: 33,024 (129.00 KB) \n", " Trainable params: 33,024 (129.00 KB) \n", " Non-trainable params: 0 (0.00 B) \n", " \n", "\n" ], "text/plain": [ "\u001b[34mINFO \u001b[0m Model: \u001b[32m\"projector\"\u001b[0m \u001b]8;id=439076;file:///workspaces/heartkit/.venv/lib/python3.12/site-packages/keras/src/utils/summary_utils.py\u001b\\\u001b[2msummary_utils.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=568200;file:///workspaces/heartkit/.venv/lib/python3.12/site-packages/keras/src/utils/summary_utils.py#389\u001b\\\u001b[2m389\u001b[0m\u001b]8;;\u001b\\\n", " ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ \u001b[2m \u001b[0m\n", " ┃ Layer \u001b[1m(\u001b[0mtype\u001b[1m)\u001b[0m ┃ Output Shape ┃ Param # ┃ \u001b[2m \u001b[0m\n", " ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ \u001b[2m \u001b[0m\n", " │ keras_tensor_109CLONE │ \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m, \u001b[1;36m128\u001b[0m\u001b[1m)\u001b[0m │ \u001b[1;36m0\u001b[0m │ \u001b[2m \u001b[0m\n", " │ \u001b[1m(\u001b[0mInputLayer\u001b[1m)\u001b[0m │ │ │ \u001b[2m \u001b[0m\n", " ├─────────────────────────────────┼────────────────────────┼───────────────┤ \u001b[2m \u001b[0m\n", " │ dense \u001b[1m(\u001b[0mDense\u001b[1m)\u001b[0m 
│ \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m, \u001b[1;36m128\u001b[0m\u001b[1m)\u001b[0m │ \u001b[1;36m16\u001b[0m,\u001b[1;36m512\u001b[0m │ \u001b[2m \u001b[0m\n", " ├─────────────────────────────────┼────────────────────────┼───────────────┤ \u001b[2m \u001b[0m\n", " │ dense_1 \u001b[1m(\u001b[0mDense\u001b[1m)\u001b[0m │ \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m, \u001b[1;36m128\u001b[0m\u001b[1m)\u001b[0m │ \u001b[1;36m16\u001b[0m,\u001b[1;36m512\u001b[0m │ \u001b[2m \u001b[0m\n", " └─────────────────────────────────┴────────────────────────┴───────────────┘ \u001b[2m \u001b[0m\n", " Total params: \u001b[1;36m33\u001b[0m,\u001b[1;36m024\u001b[0m \u001b[1m(\u001b[0m\u001b[1;36m129.00\u001b[0m KB\u001b[1m)\u001b[0m \u001b[2m \u001b[0m\n", " Trainable params: \u001b[1;36m33\u001b[0m,\u001b[1;36m024\u001b[0m \u001b[1m(\u001b[0m\u001b[1;36m129.00\u001b[0m KB\u001b[1m)\u001b[0m \u001b[2m \u001b[0m\n", " Non-trainable params: \u001b[1;36m0\u001b[0m \u001b[1m(\u001b[0m\u001b[1;36m0.00\u001b[0m B\u001b[1m)\u001b[0m \u001b[2m \u001b[0m\n", " \u001b[2m \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "projector_input = encoder_output\n", "projector_output = keras.layers.Dense(projection_width, activation=\"relu6\")(projector_input)\n", "projector_output = keras.layers.Dense(projection_width)(projector_output)\n", "projector = keras.Model(inputs=projector_input, outputs=projector_output, name=\"projector\")\n", "flops = nse.metrics.flops.get_flops(projector, batch_size=1, fpath=os.devnull)\n", "projector.summary(print_fn=logger.info)\n", "logger.debug(f\"Projector requires {flops/1e6:0.2f} MFLOPS\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create a SimCLR model to train" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "model = nse.trainers.SimCLRTrainer(\n", " encoder=encoder,\n", " augmenter=None, # We augment in the data pipeline\n", " projector=projector,\n", ")" 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compile the model\n", "\n", "We will compile the model using the Adam optimizer with a cosine decay learning rate schedule and the SimCLR contrastive loss (`nse.losses.simclr.SimCLRLoss`). We will also attach metrics and callbacks (early stopping, checkpointing, CSV logging, and optional TensorBoard logging) to monitor the training process.\n" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "def get_scheduler():\n", " return keras.optimizers.schedules.CosineDecay(\n", " initial_learning_rate=learning_rate,\n", " decay_steps=steps_per_epoch * epochs,\n", " )\n", "\n", "optimizer = keras.optimizers.Adam(get_scheduler())\n", "loss = nse.losses.simclr.SimCLRLoss(temperature=temperature)\n", "\n", "metrics = [\n", " keras.metrics.MeanSquaredError(name=\"mse\"),\n", " keras.metrics.CosineSimilarity(name=\"cos\"),\n", "]\n", "\n", "model_callbacks = [\n", " keras.callbacks.EarlyStopping(\n", " monitor=f\"val_{val_metric}\",\n", " patience=max(int(0.25 * epochs), 1),\n", " mode=val_mode,\n", " restore_best_weights=True,\n", " verbose=verbose - 1\n", " ),\n", " keras.callbacks.ModelCheckpoint(\n", " filepath=str(model_file),\n", " monitor=f\"val_{val_metric}\",\n", " save_best_only=True,\n", " mode=val_mode,\n", " verbose=verbose - 1\n", " ),\n", " keras.callbacks.CSVLogger(job_dir / \"history.csv\"),\n", "]\n", "if nse.utils.env_flag(\"TENSORBOARD\"):\n", " model_callbacks.append(\n", " keras.callbacks.TensorBoard(\n", " log_dir=job_dir,\n", " write_steps_per_second=True,\n", " )\n", " )\n", "\n", "model.compile(\n", " encoder_optimizer=optimizer,\n", " encoder_loss=loss,\n", " encoder_metrics=metrics,\n", ")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train the model" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/150\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2024-08-16 
18:54:13.839587: E tensorflow/core/util/util.cc:131] oneDNN 
supports DT_INT32 only on platforms with AVX-512. Falling back to the default Eigen-based implementation if present.\n", "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "I0000 00:00:1723834463.457755 712486 service.cc:146] XLA service 0x78321c02f130 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:\n", "I0000 00:00:1723834463.457771 712486 service.cc:154] StreamExecutor device (0): NVIDIA GeForce RTX 4090, Compute Capability 8.9\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m 1/25\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m13:39\u001b[0m 34s/step - cos: 0.5956 - loss: 15.6336 - mse: 0.2352" ] }, { "name": "stderr", "output_type": "stream", "text": [ "I0000 00:00:1723834487.410060 712486 device_compiler.h:188] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m67s\u001b[0m 1s/step - cos: 0.6157 - loss: 14.9098 - mse: 0.2319 - val_cos: 0.6770 - val_loss: 12.6894 - val_mse: 0.2770\n", "Epoch 2/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 228ms/step - cos: 0.6928 - loss: 12.2036 - mse: 0.2814 - val_cos: 0.7274 - val_loss: 11.2915 - val_mse: 0.2797\n", "Epoch 3/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 183ms/step - cos: 0.7322 - loss: 11.1098 - mse: 0.2783 - val_cos: 0.7428 - val_loss: 10.5851 - val_mse: 0.2743\n", "Epoch 4/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 185ms/step - cos: 0.7449 - loss: 10.4056 - mse: 0.2715 - val_cos: 0.7517 - val_loss: 9.9517 - val_mse: 0.2724\n", "Epoch 5/150\n", "\u001b[1m25/25\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 184ms/step - cos: 0.7523 - loss: 9.8387 - mse: 0.2707 - val_cos: 0.7568 - val_loss: 9.5624 - val_mse: 0.2703\n", "Epoch 6/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 183ms/step - cos: 0.7548 - loss: 9.5425 - mse: 0.2690 - val_cos: 0.7591 - val_loss: 9.2802 - val_mse: 0.2633\n", "Epoch 7/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 179ms/step - cos: 0.7587 - loss: 9.2489 - mse: 0.2617 - val_cos: 0.7604 - val_loss: 9.0665 - val_mse: 0.2585\n", "Epoch 8/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 185ms/step - cos: 0.7604 - loss: 9.0068 - mse: 0.2579 - val_cos: 0.7623 - val_loss: 8.8123 - val_mse: 0.2564\n", "Epoch 9/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 182ms/step - cos: 0.7618 - loss: 8.7503 - mse: 0.2550 - val_cos: 0.7628 - val_loss: 8.5923 - val_mse: 0.2538\n", "Epoch 10/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 179ms/step - cos: 0.7621 - loss: 8.5523 - mse: 0.2549 - val_cos: 0.7622 - val_loss: 8.4131 - val_mse: 0.2523\n", "Epoch 11/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 186ms/step - cos: 0.7624 - loss: 8.3957 - mse: 0.2511 - val_cos: 0.7635 - val_loss: 8.2374 - val_mse: 0.2495\n", "Epoch 12/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 181ms/step - cos: 0.7637 - loss: 8.2014 - mse: 0.2498 - val_cos: 0.7641 - val_loss: 8.0899 - val_mse: 0.2478\n", "Epoch 13/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 181ms/step - cos: 0.7639 - loss: 
8.0752 - mse: 0.2456 - val_cos: 0.7645 - val_loss: 7.9631 - val_mse: 0.2451\n", "Epoch 14/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 177ms/step - cos: 0.7638 - loss: 7.9306 - mse: 0.2457 - val_cos: 0.7665 - val_loss: 7.8171 - val_mse: 0.2403\n", "Epoch 15/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 182ms/step - cos: 0.7642 - loss: 7.8377 - mse: 0.2410 - val_cos: 0.7663 - val_loss: 7.7359 - val_mse: 0.2385\n", "Epoch 16/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 185ms/step - cos: 0.7658 - loss: 7.6886 - mse: 0.2378 - val_cos: 0.7676 - val_loss: 7.6044 - val_mse: 0.2350\n", "Epoch 17/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 188ms/step - cos: 0.7643 - loss: 7.6359 - mse: 0.2369 - val_cos: 0.7659 - val_loss: 7.5199 - val_mse: 0.2345\n", "Epoch 18/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 184ms/step - cos: 0.7660 - loss: 7.5126 - mse: 0.2329 - val_cos: 0.7680 - val_loss: 7.4207 - val_mse: 0.2301\n", "Epoch 19/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 177ms/step - cos: 0.7651 - loss: 7.4191 - mse: 0.2304 - val_cos: 0.7682 - val_loss: 7.3130 - val_mse: 0.2268\n", "Epoch 20/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 182ms/step - cos: 0.7651 - loss: 7.3419 - mse: 0.2291 - val_cos: 0.7664 - val_loss: 7.2225 - val_mse: 0.2272\n", "Epoch 21/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 179ms/step - cos: 0.7657 - loss: 7.2691 - mse: 0.2277 - val_cos: 0.7665 - val_loss: 7.1630 - val_mse: 0.2245\n", "Epoch 22/150\n", 
"\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 182ms/step - cos: 0.7640 - loss: 7.2177 - mse: 0.2248 - val_cos: 0.7662 - val_loss: 7.0724 - val_mse: 0.2219\n", "Epoch 23/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 181ms/step - cos: 0.7679 - loss: 7.0468 - mse: 0.2195 - val_cos: 0.7680 - val_loss: 6.9664 - val_mse: 0.2184\n", "Epoch 24/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 178ms/step - cos: 0.7667 - loss: 6.9840 - mse: 0.2171 - val_cos: 0.7669 - val_loss: 6.9237 - val_mse: 0.2178\n", "Epoch 25/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 181ms/step - cos: 0.7662 - loss: 6.9243 - mse: 0.2169 - val_cos: 0.7666 - val_loss: 6.8773 - val_mse: 0.2136\n", "Epoch 26/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 180ms/step - cos: 0.7655 - loss: 6.8518 - mse: 0.2143 - val_cos: 0.7668 - val_loss: 6.7758 - val_mse: 0.2124\n", "Epoch 27/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 181ms/step - cos: 0.7667 - loss: 6.7623 - mse: 0.2110 - val_cos: 0.7664 - val_loss: 6.7287 - val_mse: 0.2101\n", "Epoch 28/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 184ms/step - cos: 0.7676 - loss: 6.7556 - mse: 0.2077 - val_cos: 0.7678 - val_loss: 6.6686 - val_mse: 0.2059\n", "Epoch 29/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 191ms/step - cos: 0.7671 - loss: 6.6939 - mse: 0.2065 - val_cos: 0.7670 - val_loss: 6.6024 - val_mse: 0.2012\n", "Epoch 30/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 
186ms/step - cos: 0.7660 - loss: 6.6050 - mse: 0.2017 - val_cos: 0.7678 - val_loss: 6.5662 - val_mse: 0.1994\n", "Epoch 31/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 180ms/step - cos: 0.7667 - loss: 6.5798 - mse: 0.2007 - val_cos: 0.7677 - val_loss: 6.5317 - val_mse: 0.1979\n", "Epoch 32/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 180ms/step - cos: 0.7669 - loss: 6.5304 - mse: 0.1988 - val_cos: 0.7691 - val_loss: 6.4457 - val_mse: 0.1951\n", "Epoch 33/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 184ms/step - cos: 0.7671 - loss: 6.4863 - mse: 0.1965 - val_cos: 0.7678 - val_loss: 6.4010 - val_mse: 0.1941\n", "Epoch 34/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 185ms/step - cos: 0.7666 - loss: 6.4082 - mse: 0.1940 - val_cos: 0.7678 - val_loss: 6.3757 - val_mse: 0.1933\n", "Epoch 35/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 189ms/step - cos: 0.7677 - loss: 6.3730 - mse: 0.1909 - val_cos: 0.7692 - val_loss: 6.3082 - val_mse: 0.1881\n", "Epoch 36/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 177ms/step - cos: 0.7677 - loss: 6.3429 - mse: 0.1880 - val_cos: 0.7681 - val_loss: 6.2834 - val_mse: 0.1878\n", "Epoch 37/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 179ms/step - cos: 0.7671 - loss: 6.2941 - mse: 0.1861 - val_cos: 0.7697 - val_loss: 6.2232 - val_mse: 0.1849\n", "Epoch 38/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 178ms/step - cos: 0.7670 - loss: 6.2765 - mse: 0.1862 - val_cos: 0.7684 - val_loss: 6.1971 - val_mse: 
0.1828\n", "Epoch 39/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 179ms/step - cos: 0.7664 - loss: 6.2457 - mse: 0.1831 - val_cos: 0.7686 - val_loss: 6.1664 - val_mse: 0.1812\n", "Epoch 40/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 179ms/step - cos: 0.7698 - loss: 6.1896 - mse: 0.1797 - val_cos: 0.7696 - val_loss: 6.1331 - val_mse: 0.1777\n", "Epoch 41/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 178ms/step - cos: 0.7670 - loss: 6.1657 - mse: 0.1788 - val_cos: 0.7701 - val_loss: 6.1057 - val_mse: 0.1760\n", "Epoch 42/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 180ms/step - cos: 0.7690 - loss: 6.0656 - mse: 0.1760 - val_cos: 0.7693 - val_loss: 6.0554 - val_mse: 0.1738\n", "Epoch 43/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 178ms/step - cos: 0.7682 - loss: 6.0856 - mse: 0.1745 - val_cos: 0.7676 - val_loss: 6.0448 - val_mse: 0.1722\n", "Epoch 44/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 180ms/step - cos: 0.7665 - loss: 6.0528 - mse: 0.1724 - val_cos: 0.7683 - val_loss: 6.0189 - val_mse: 0.1710\n", "Epoch 45/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 186ms/step - cos: 0.7691 - loss: 6.0253 - mse: 0.1699 - val_cos: 0.7685 - val_loss: 5.9979 - val_mse: 0.1665\n", "Epoch 46/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 179ms/step - cos: 0.7679 - loss: 5.9833 - mse: 0.1665 - val_cos: 0.7681 - val_loss: 5.9251 - val_mse: 0.1675\n", "Epoch 47/150\n", "\u001b[1m25/25\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 175ms/step - cos: 0.7680 - loss: 5.9603 - mse: 0.1664 - val_cos: 0.7698 - val_loss: 5.9433 - val_mse: 0.1651\n", "Epoch 48/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 182ms/step - cos: 0.7701 - loss: 5.9152 - mse: 0.1653 - val_cos: 0.7703 - val_loss: 5.9054 - val_mse: 0.1632\n", "Epoch 49/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 180ms/step - cos: 0.7682 - loss: 5.8829 - mse: 0.1632 - val_cos: 0.7692 - val_loss: 5.8782 - val_mse: 0.1611\n", "Epoch 50/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 182ms/step - cos: 0.7683 - loss: 5.8843 - mse: 0.1602 - val_cos: 0.7705 - val_loss: 5.8711 - val_mse: 0.1598\n", "Epoch 51/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 178ms/step - cos: 0.7687 - loss: 5.8453 - mse: 0.1596 - val_cos: 0.7680 - val_loss: 5.8498 - val_mse: 0.1603\n", "Epoch 52/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 179ms/step - cos: 0.7685 - loss: 5.8001 - mse: 0.1577 - val_cos: 0.7699 - val_loss: 5.7597 - val_mse: 0.1563\n", "Epoch 53/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 173ms/step - cos: 0.7685 - loss: 5.7991 - mse: 0.1569 - val_cos: 0.7682 - val_loss: 5.7875 - val_mse: 0.1550\n", "Epoch 54/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 176ms/step - cos: 0.7680 - loss: 5.7853 - mse: 0.1547 - val_cos: 0.7707 - val_loss: 5.7683 - val_mse: 0.1524\n", "Epoch 55/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 179ms/step - cos: 0.7691 - 
loss: 5.7863 - mse: 0.1526 - val_cos: 0.7705 - val_loss: 5.7501 - val_mse: 0.1514\n", "Epoch 56/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 179ms/step - cos: 0.7692 - loss: 5.7813 - mse: 0.1511 - val_cos: 0.7694 - val_loss: 5.7335 - val_mse: 0.1502\n", "Epoch 57/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 179ms/step - cos: 0.7699 - loss: 5.7194 - mse: 0.1498 - val_cos: 0.7694 - val_loss: 5.7055 - val_mse: 0.1492\n", "Epoch 58/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 181ms/step - cos: 0.7704 - loss: 5.6757 - mse: 0.1483 - val_cos: 0.7700 - val_loss: 5.6847 - val_mse: 0.1472\n", "Epoch 59/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 181ms/step - cos: 0.7690 - loss: 5.7145 - mse: 0.1485 - val_cos: 0.7699 - val_loss: 5.6508 - val_mse: 0.1456\n", "Epoch 60/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 180ms/step - cos: 0.7673 - loss: 5.6932 - mse: 0.1473 - val_cos: 0.7707 - val_loss: 5.6501 - val_mse: 0.1436\n", "Epoch 61/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 180ms/step - cos: 0.7694 - loss: 5.6243 - mse: 0.1447 - val_cos: 0.7689 - val_loss: 5.6231 - val_mse: 0.1428\n", "Epoch 62/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 180ms/step - cos: 0.7684 - loss: 5.6316 - mse: 0.1423 - val_cos: 0.7688 - val_loss: 5.5892 - val_mse: 0.1425\n", "Epoch 63/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 177ms/step - cos: 0.7677 - loss: 5.6548 - mse: 0.1434 - val_cos: 0.7710 - val_loss: 5.5681 - val_mse: 0.1399\n", "Epoch 64/150\n", 
"\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 174ms/step - cos: 0.7680 - loss: 5.6244 - mse: 0.1421 - val_cos: 0.7698 - val_loss: 5.5903 - val_mse: 0.1400\n", "Epoch 65/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 180ms/step - cos: 0.7681 - loss: 5.6289 - mse: 0.1406 - val_cos: 0.7687 - val_loss: 5.5534 - val_mse: 0.1409\n", "Epoch 66/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 177ms/step - cos: 0.7688 - loss: 5.5736 - mse: 0.1403 - val_cos: 0.7702 - val_loss: 5.5605 - val_mse: 0.1376\n", "Epoch 67/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 179ms/step - cos: 0.7700 - loss: 5.5189 - mse: 0.1380 - val_cos: 0.7702 - val_loss: 5.5123 - val_mse: 0.1363\n", "Epoch 68/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 175ms/step - cos: 0.7687 - loss: 5.5515 - mse: 0.1369 - val_cos: 0.7691 - val_loss: 5.5241 - val_mse: 0.1370\n", "Epoch 69/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 181ms/step - cos: 0.7702 - loss: 5.5545 - mse: 0.1357 - val_cos: 0.7699 - val_loss: 5.4955 - val_mse: 0.1362\n", "Epoch 70/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 179ms/step - cos: 0.7690 - loss: 5.4659 - mse: 0.1352 - val_cos: 0.7703 - val_loss: 5.4853 - val_mse: 0.1337\n", "Epoch 71/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 179ms/step - cos: 0.7681 - loss: 5.4991 - mse: 0.1344 - val_cos: 0.7683 - val_loss: 5.4826 - val_mse: 0.1333\n", "Epoch 72/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 
177ms/step - cos: 0.7681 - loss: 5.4836 - mse: 0.1327 - val_cos: 0.7693 - val_loss: 5.4592 - val_mse: 0.1316\n", "Epoch 73/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 178ms/step - cos: 0.7702 - loss: 5.4963 - mse: 0.1315 - val_cos: 0.7706 - val_loss: 5.4468 - val_mse: 0.1308\n", "Epoch 74/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 178ms/step - cos: 0.7696 - loss: 5.3915 - mse: 0.1302 - val_cos: 0.7698 - val_loss: 5.4245 - val_mse: 0.1298\n", "Epoch 75/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 178ms/step - cos: 0.7706 - loss: 5.4288 - mse: 0.1288 - val_cos: 0.7695 - val_loss: 5.3944 - val_mse: 0.1290\n", "Epoch 76/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 180ms/step - cos: 0.7676 - loss: 5.4072 - mse: 0.1294 - val_cos: 0.7708 - val_loss: 5.3982 - val_mse: 0.1279\n", "Epoch 77/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 179ms/step - cos: 0.7688 - loss: 5.3941 - mse: 0.1292 - val_cos: 0.7698 - val_loss: 5.4304 - val_mse: 0.1282\n", "Epoch 78/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 179ms/step - cos: 0.7692 - loss: 5.4147 - mse: 0.1282 - val_cos: 0.7707 - val_loss: 5.3892 - val_mse: 0.1265\n", "Epoch 79/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 177ms/step - cos: 0.7703 - loss: 5.3819 - mse: 0.1260 - val_cos: 0.7696 - val_loss: 5.3757 - val_mse: 0.1265\n", "Epoch 80/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 181ms/step - cos: 0.7691 - loss: 5.3872 - mse: 0.1262 - val_cos: 0.7688 - val_loss: 5.3662 - val_mse: 
0.1262\n", "Epoch 81/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 179ms/step - cos: 0.7701 - loss: 5.3129 - mse: 0.1245 - val_cos: 0.7701 - val_loss: 5.3568 - val_mse: 0.1245\n", "Epoch 82/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 182ms/step - cos: 0.7690 - loss: 5.3379 - mse: 0.1245 - val_cos: 0.7694 - val_loss: 5.3354 - val_mse: 0.1242\n", "Epoch 83/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 180ms/step - cos: 0.7687 - loss: 5.3438 - mse: 0.1245 - val_cos: 0.7719 - val_loss: 5.3168 - val_mse: 0.1228\n", "Epoch 84/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 182ms/step - cos: 0.7681 - loss: 5.3040 - mse: 0.1235 - val_cos: 0.7715 - val_loss: 5.3151 - val_mse: 0.1220\n", "Epoch 85/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 183ms/step - cos: 0.7685 - loss: 5.3504 - mse: 0.1237 - val_cos: 0.7695 - val_loss: 5.3025 - val_mse: 0.1231\n", "Epoch 86/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 174ms/step - cos: 0.7685 - loss: 5.3010 - mse: 0.1224 - val_cos: 0.7705 - val_loss: 5.3040 - val_mse: 0.1212\n", "Epoch 87/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 183ms/step - cos: 0.7702 - loss: 5.2738 - mse: 0.1207 - val_cos: 0.7702 - val_loss: 5.2965 - val_mse: 0.1218\n", "Epoch 88/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 180ms/step - cos: 0.7689 - loss: 5.2917 - mse: 0.1206 - val_cos: 0.7699 - val_loss: 5.2888 - val_mse: 0.1208\n", "Epoch 89/150\n", "\u001b[1m25/25\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 180ms/step - cos: 0.7696 - loss: 5.3199 - mse: 0.1208 - val_cos: 0.7689 - val_loss: 5.2589 - val_mse: 0.1208\n", "Epoch 90/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 179ms/step - cos: 0.7682 - loss: 5.2979 - mse: 0.1212 - val_cos: 0.7711 - val_loss: 5.2490 - val_mse: 0.1197\n", "Epoch 91/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 174ms/step - cos: 0.7701 - loss: 5.2316 - mse: 0.1198 - val_cos: 0.7712 - val_loss: 5.2642 - val_mse: 0.1194\n", "Epoch 92/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 182ms/step - cos: 0.7691 - loss: 5.2812 - mse: 0.1199 - val_cos: 0.7704 - val_loss: 5.2346 - val_mse: 0.1190\n", "Epoch 93/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 175ms/step - cos: 0.7688 - loss: 5.2679 - mse: 0.1191 - val_cos: 0.7693 - val_loss: 5.2493 - val_mse: 0.1184\n", "Epoch 94/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 179ms/step - cos: 0.7690 - loss: 5.2947 - mse: 0.1185 - val_cos: 0.7703 - val_loss: 5.2468 - val_mse: 0.1179\n", "Epoch 95/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 181ms/step - cos: 0.7697 - loss: 5.2224 - mse: 0.1174 - val_cos: 0.7699 - val_loss: 5.2175 - val_mse: 0.1174\n", "Epoch 96/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 180ms/step - cos: 0.7679 - loss: 5.2491 - mse: 0.1178 - val_cos: 0.7706 - val_loss: 5.2031 - val_mse: 0.1174\n", "Epoch 97/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 176ms/step - cos: 0.7704 - 
loss: 5.2146 - mse: 0.1168 - val_cos: 0.7690 - val_loss: 5.1959 - val_mse: 0.1174\n", "Epoch 98/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 178ms/step - cos: 0.7698 - loss: 5.1986 - mse: 0.1171 - val_cos: 0.7694 - val_loss: 5.1951 - val_mse: 0.1169\n", "Epoch 99/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 175ms/step - cos: 0.7685 - loss: 5.1510 - mse: 0.1173 - val_cos: 0.7692 - val_loss: 5.2092 - val_mse: 0.1164\n", "Epoch 100/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 174ms/step - cos: 0.7700 - loss: 5.1515 - mse: 0.1160 - val_cos: 0.7696 - val_loss: 5.2035 - val_mse: 0.1160\n", "Epoch 101/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 181ms/step - cos: 0.7685 - loss: 5.2375 - mse: 0.1161 - val_cos: 0.7713 - val_loss: 5.1944 - val_mse: 0.1159\n", "Epoch 102/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 175ms/step - cos: 0.7689 - loss: 5.1949 - mse: 0.1157 - val_cos: 0.7705 - val_loss: 5.1947 - val_mse: 0.1150\n", "Epoch 103/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 181ms/step - cos: 0.7692 - loss: 5.1795 - mse: 0.1150 - val_cos: 0.7703 - val_loss: 5.1872 - val_mse: 0.1147\n", "Epoch 104/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 179ms/step - cos: 0.7689 - loss: 5.1701 - mse: 0.1155 - val_cos: 0.7706 - val_loss: 5.1679 - val_mse: 0.1149\n", "Epoch 105/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 173ms/step - cos: 0.7685 - loss: 5.1989 - mse: 0.1154 - val_cos: 0.7689 - val_loss: 5.1848 - val_mse: 0.1153\n", "Epoch 106/150\n", 
"\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 182ms/step - cos: 0.7691 - loss: 5.1822 - mse: 0.1145 - val_cos: 0.7703 - val_loss: 5.1448 - val_mse: 0.1142\n", "Epoch 107/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 175ms/step - cos: 0.7695 - loss: 5.1392 - mse: 0.1146 - val_cos: 0.7708 - val_loss: 5.1465 - val_mse: 0.1139\n", "Epoch 108/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 177ms/step - cos: 0.7692 - loss: 5.2153 - mse: 0.1145 - val_cos: 0.7705 - val_loss: 5.1640 - val_mse: 0.1136\n", "Epoch 109/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 176ms/step - cos: 0.7690 - loss: 5.1583 - mse: 0.1140 - val_cos: 0.7689 - val_loss: 5.1519 - val_mse: 0.1142\n", "Epoch 110/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 178ms/step - cos: 0.7700 - loss: 5.1384 - mse: 0.1134 - val_cos: 0.7688 - val_loss: 5.1593 - val_mse: 0.1139\n", "Epoch 111/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 180ms/step - cos: 0.7695 - loss: 5.1484 - mse: 0.1134 - val_cos: 0.7709 - val_loss: 5.1299 - val_mse: 0.1132\n", "Epoch 112/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 181ms/step - cos: 0.7699 - loss: 5.1683 - mse: 0.1126 - val_cos: 0.7698 - val_loss: 5.1275 - val_mse: 0.1131\n", "Epoch 113/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 171ms/step - cos: 0.7694 - loss: 5.1230 - mse: 0.1123 - val_cos: 0.7703 - val_loss: 5.1364 - val_mse: 0.1121\n", "Epoch 114/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m 
\u001b[1m4s\u001b[0m 178ms/step - cos: 0.7699 - loss: 5.1434 - mse: 0.1129 - val_cos: 0.7691 - val_loss: 5.1523 - val_mse: 0.1132\n", "Epoch 115/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 178ms/step - cos: 0.7686 - loss: 5.1086 - mse: 0.1123 - val_cos: 0.7695 - val_loss: 5.1388 - val_mse: 0.1123\n", "Epoch 116/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 177ms/step - cos: 0.7700 - loss: 5.1089 - mse: 0.1121 - val_cos: 0.7698 - val_loss: 5.1056 - val_mse: 0.1125\n", "Epoch 117/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 176ms/step - cos: 0.7708 - loss: 5.0898 - mse: 0.1122 - val_cos: 0.7715 - val_loss: 5.1041 - val_mse: 0.1120\n", "Epoch 118/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 174ms/step - cos: 0.7688 - loss: 5.1048 - mse: 0.1123 - val_cos: 0.7698 - val_loss: 5.1103 - val_mse: 0.1117\n", "Epoch 119/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 177ms/step - cos: 0.7690 - loss: 5.1339 - mse: 0.1123 - val_cos: 0.7707 - val_loss: 5.0992 - val_mse: 0.1114\n", "Epoch 120/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 174ms/step - cos: 0.7707 - loss: 5.0996 - mse: 0.1114 - val_cos: 0.7691 - val_loss: 5.1405 - val_mse: 0.1121\n", "Epoch 121/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 176ms/step - cos: 0.7706 - loss: 5.0921 - mse: 0.1117 - val_cos: 0.7705 - val_loss: 5.1123 - val_mse: 0.1117\n", "Epoch 122/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 175ms/step - cos: 0.7694 - loss: 5.1215 - mse: 0.1118 - val_cos: 0.7730 - 
val_loss: 5.1020 - val_mse: 0.1101\n", "Epoch 123/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 173ms/step - cos: 0.7694 - loss: 5.1185 - mse: 0.1113 - val_cos: 0.7713 - val_loss: 5.1067 - val_mse: 0.1113\n", "Epoch 124/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 175ms/step - cos: 0.7676 - loss: 5.1077 - mse: 0.1121 - val_cos: 0.7699 - val_loss: 5.1011 - val_mse: 0.1119\n", "Epoch 125/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 178ms/step - cos: 0.7692 - loss: 5.1002 - mse: 0.1116 - val_cos: 0.7722 - val_loss: 5.0920 - val_mse: 0.1106\n", "Epoch 126/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 179ms/step - cos: 0.7700 - loss: 5.0861 - mse: 0.1109 - val_cos: 0.7708 - val_loss: 5.0755 - val_mse: 0.1110\n", "Epoch 127/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 176ms/step - cos: 0.7687 - loss: 5.1179 - mse: 0.1116 - val_cos: 0.7701 - val_loss: 5.0813 - val_mse: 0.1113\n", "Epoch 128/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 176ms/step - cos: 0.7691 - loss: 5.0677 - mse: 0.1114 - val_cos: 0.7712 - val_loss: 5.0920 - val_mse: 0.1111\n", "Epoch 129/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 182ms/step - cos: 0.7693 - loss: 5.0750 - mse: 0.1109 - val_cos: 0.7697 - val_loss: 5.1003 - val_mse: 0.1117\n", "Epoch 130/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 180ms/step - cos: 0.7696 - loss: 5.1088 - mse: 0.1111 - val_cos: 0.7700 - val_loss: 5.1090 - val_mse: 0.1112\n", "Epoch 131/150\n", "\u001b[1m25/25\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 180ms/step - cos: 0.7710 - loss: 5.0843 - mse: 0.1103 - val_cos: 0.7703 - val_loss: 5.0754 - val_mse: 0.1116\n", "Epoch 132/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 178ms/step - cos: 0.7694 - loss: 5.0816 - mse: 0.1113 - val_cos: 0.7695 - val_loss: 5.0800 - val_mse: 0.1109\n", "Epoch 133/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 183ms/step - cos: 0.7690 - loss: 5.0900 - mse: 0.1110 - val_cos: 0.7691 - val_loss: 5.1067 - val_mse: 0.1107\n", "Epoch 134/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 179ms/step - cos: 0.7687 - loss: 5.1286 - mse: 0.1116 - val_cos: 0.7706 - val_loss: 5.0937 - val_mse: 0.1104\n", "Epoch 135/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 179ms/step - cos: 0.7699 - loss: 5.0638 - mse: 0.1106 - val_cos: 0.7692 - val_loss: 5.1000 - val_mse: 0.1115\n", "Epoch 136/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 178ms/step - cos: 0.7696 - loss: 5.0928 - mse: 0.1109 - val_cos: 0.7711 - val_loss: 5.1196 - val_mse: 0.1105\n", "Epoch 137/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 178ms/step - cos: 0.7688 - loss: 5.0861 - mse: 0.1113 - val_cos: 0.7689 - val_loss: 5.0883 - val_mse: 0.1112\n", "Epoch 138/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 182ms/step - cos: 0.7705 - loss: 5.0776 - mse: 0.1104 - val_cos: 0.7706 - val_loss: 5.0706 - val_mse: 0.1108\n", "Epoch 139/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 180ms/step - cos: 
0.7708 - loss: 5.0805 - mse: 0.1106 - val_cos: 0.7694 - val_loss: 5.0848 - val_mse: 0.1114\n", "Epoch 140/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 183ms/step - cos: 0.7709 - loss: 5.0705 - mse: 0.1100 - val_cos: 0.7696 - val_loss: 5.1025 - val_mse: 0.1108\n", "Epoch 141/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 184ms/step - cos: 0.7689 - loss: 5.0755 - mse: 0.1111 - val_cos: 0.7695 - val_loss: 5.0697 - val_mse: 0.1109\n", "Epoch 142/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 176ms/step - cos: 0.7693 - loss: 5.0860 - mse: 0.1110 - val_cos: 0.7698 - val_loss: 5.0901 - val_mse: 0.1108\n", "Epoch 143/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 182ms/step - cos: 0.7703 - loss: 5.0945 - mse: 0.1105 - val_cos: 0.7703 - val_loss: 5.0849 - val_mse: 0.1110\n", "Epoch 144/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 181ms/step - cos: 0.7682 - loss: 5.0852 - mse: 0.1109 - val_cos: 0.7705 - val_loss: 5.0823 - val_mse: 0.1107\n", "Epoch 145/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 183ms/step - cos: 0.7700 - loss: 5.0820 - mse: 0.1099 - val_cos: 0.7691 - val_loss: 5.0824 - val_mse: 0.1114\n", "Epoch 146/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 184ms/step - cos: 0.7698 - loss: 5.1090 - mse: 0.1105 - val_cos: 0.7697 - val_loss: 5.0849 - val_mse: 0.1113\n", "Epoch 147/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 177ms/step - cos: 0.7699 - loss: 5.0637 - mse: 0.1106 - val_cos: 0.7702 - val_loss: 5.0996 - val_mse: 0.1107\n", "Epoch 
148/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 179ms/step - cos: 0.7708 - loss: 5.0515 - mse: 0.1101 - val_cos: 0.7695 - val_loss: 5.0811 - val_mse: 0.1111\n", "Epoch 149/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 192ms/step - cos: 0.7692 - loss: 5.0959 - mse: 0.1111 - val_cos: 0.7705 - val_loss: 5.1056 - val_mse: 0.1106\n", "Epoch 150/150\n", "\u001b[1m25/25\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 178ms/step - cos: 0.7685 - loss: 5.1008 - mse: 0.1110 - val_cos: 0.7713 - val_loss: 5.0885 - val_mse: 0.1105\n" ] } ], "source": [ "history = model.fit(\n", " train_ds,\n", " steps_per_epoch=steps_per_epoch,\n", " verbose=verbose,\n", " epochs=epochs,\n", " validation_data=val_ds,\n", " callbacks=model_callbacks,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Visualize training history\n", "\n", "Let's visualize the training history to understand the model's performance during training. This will help to ensure the model is learning and not under or overfitting." ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
INFO [VAL SET] MSE=0.0132, COS=0.9683 4122487501.py:2\n", "\n" ], "text/plain": [ "\u001b[34mINFO \u001b[0m \u001b[1m[\u001b[0mVAL SET\u001b[1m]\u001b[0m \u001b[33mMSE\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1;36m.0132\u001b[0m, \u001b[33mCOS\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1;36m.9683\u001b[0m \u001b]8;id=945728;file:///tmp/ipykernel_712291/4122487501.py\u001b\\\u001b[2m4122487501.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=966210;file:///tmp/ipykernel_712291/4122487501.py#2\u001b\\\u001b[2m2\u001b[0m\u001b]8;;\u001b\\\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "rst = nse.metrics.compute_metrics(metrics, test_y1, test_y2)\n", "logger.info(\"[VAL SET] \" + \", \".join([f\"{k.upper()}={v:.4f}\" for k, v in rst.items()]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export model to TF Lite / TFLM\n", "\n", "Once the model has been trained and evaluated, we need to export it to a format that can be used for inference on the edge. We export the model to the TensorFlow Lite flatbuffer format and also generate a C header file that can be used with TensorFlow Lite for Microcontrollers (TFLM).\n", "\n", "For this model, we export a 32-bit floating-point (FP32) model.\n", "\n", "__NOTE:__ We use `CONCRETE` mode to lower the model to concrete functions before converting, because the TF (MLIR) converter fails to properly lower the dilated convolutional layers."
] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "W0000 00:00:1723835186.987318 712291 tf_tfl_flatbuffer_helpers.cc:392] Ignored output_format.\n", "W0000 00:00:1723835186.987329 712291 tf_tfl_flatbuffer_helpers.cc:395] Ignored drop_control_dependency.\n" ] } ], "source": [ "converter = nse.converters.tflite.TfLiteKerasConverter(model=encoder)\n", "\n", "# Redirect stdout and stderr to devnull since TFLite converter is very verbose\n", "with open(os.devnull, 'w') as devnull:\n", " with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):\n", " tflite_content = converter.convert(\n", " test_x=test_x1,\n", " quantization=\"FP32\",\n", " io_type=\"float32\",\n", " mode=\"KERAS\",\n", " strict=False,\n", " verbose=verbose\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Save TFLite model as both a file and C header" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "converter.export(\n", " tflite_path=job_dir / \"model.tflite\"\n", ")\n", "\n", "converter.export_header(\n", " header_path=job_dir / \"model.h\",\n", " name=\"model\",\n", ")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Evaluate TFLite model against TensorFlow model\n", "\n", "We will instantiate a tflite interpreter and evaluate the model on the test dataset. This will help us ensure that the model has been exported correctly and is ready for deployment." 
] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO: Created TensorFlow Lite XNNPACK delegate for CPU.\n" ] } ], "source": [ "tflite = nse.interpreters.tflite.TfLiteKerasInterpreter(tflite_content)\n", "tflite.compile()" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m288/288\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step\n", "\u001b[1m288/288\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step\n" ] } ], "source": [ "y1_pred_tf = encoder.predict(test_x1)\n", "y2_pred_tf = encoder.predict(test_x2)\n", "\n", "y1_pred_tfl = tflite.predict(x=test_x1)\n", "y2_pred_tfl = tflite.predict(x=test_x2)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
INFO [TF METRICS] MSE=0.0132 COS=0.9683 2850812944.py:3\n", "\n" ], "text/plain": [ "\u001b[34mINFO \u001b[0m \u001b[1m[\u001b[0mTF METRICS\u001b[1m]\u001b[0m \u001b[33mMSE\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1;36m.0132\u001b[0m \u001b[33mCOS\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1;36m.9683\u001b[0m \u001b]8;id=395402;file:///tmp/ipykernel_712291/2850812944.py\u001b\\\u001b[2m2850812944.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=569945;file:///tmp/ipykernel_712291/2850812944.py#3\u001b\\\u001b[2m3\u001b[0m\u001b]8;;\u001b\\\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
INFO [TFL METRICS] MSE=0.0132 COS=0.9683 2850812944.py:4\n", "\n" ], "text/plain": [ "\u001b[34mINFO \u001b[0m \u001b[1m[\u001b[0mTFL METRICS\u001b[1m]\u001b[0m \u001b[33mMSE\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1;36m.0132\u001b[0m \u001b[33mCOS\u001b[0m=\u001b[1;36m0\u001b[0m\u001b[1;36m.9683\u001b[0m \u001b]8;id=984174;file:///tmp/ipykernel_712291/2850812944.py\u001b\\\u001b[2m2850812944.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=540128;file:///tmp/ipykernel_712291/2850812944.py#4\u001b\\\u001b[2m4\u001b[0m\u001b]8;;\u001b\\\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "tf_rst = nse.metrics.compute_metrics(metrics, y1_pred_tf, y2_pred_tf)\n", "tfl_rst = nse.metrics.compute_metrics(metrics, y1_pred_tfl, y2_pred_tfl)\n", "logger.info(\"[TF METRICS] \" + \" \".join([f\"{k.upper()}={v:.4f}\" for k, v in tf_rst.items()]))\n", "logger.info(\"[TFL METRICS] \" + \" \".join([f\"{k.upper()}={v:.4f}\" for k, v in tfl_rst.items()]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## ECG Foundation Demo\n", "\n", "Finally, we showcase the foundation model by computing embeddings for a large set of patients and projecting them to two dimensions with t-SNE. The resulting plot shows how the model clusters the data and whether it is learning useful features." ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "