Skip to content

demo

Classes

Functions

demo

demo(params: HKTaskParams)

Run segmentation demo.

Parameters:

- `params` (`HKTaskParams`): Demo parameters.

Source code in heartkit/tasks/segmentation/demo.py
def demo(params: HKTaskParams):
    """Run segmentation demo.

    Picks one of the configured datasets at random, pulls a single signal from
    its test patients, runs it frame by frame through the selected backend
    runner, derives peaks from the predicted segments, computes RR-interval /
    HRV metrics and writes two reports into ``params.job_dir``: an interactive
    ``demo.html`` (plotly) and a static ``demo.png`` (matplotlib).

    Args:
        params (HKTaskParams): Demo parameters

    Raises:
        ValueError: If ``params.signal_type`` is neither ``ECG`` nor ``PPG``.
    """
    logger = helia.utils.setup_logger(__name__, level=params.verbose)
    plot_theme = setup_plotting()

    signal_type = getattr(params, "signal_type", "ECG").upper()  # ECG or PPG
    # NOTE: mutates params so that later readers see the effective demo size.
    params.demo_size = params.demo_size or params.frame_size

    # Load backend inference engine
    runner = BackendFactory.get(params.backend)(params)

    classes = sorted(set(params.class_map.values()))
    class_names = params.class_names or [f"Class {i}" for i in range(params.num_classes)]

    feat_shape = (params.frame_size, 1)

    datasets = [DatasetFactory.get(ds.name)(cacheable=False, **ds.params) for ds in params.datasets]
    ds = random.choice(datasets)

    ds_gen = ds.signal_generator(
        patient_generator=helia.utils.uniform_id_generator(ds.get_test_patient_ids(), repeat=False),
        frame_size=params.demo_size,
        samples_per_patient=5,
        target_rate=params.sampling_rate,
    )
    x = next(ds_gen)

    augmenter = create_augmentation_pipeline(
        augmentations=params.augmentations + params.preprocesses,
        sampling_rate=params.sampling_rate,
    )

    runner.open()
    try:
        logger.debug("Running inference")
        y_pred = np.zeros(x.size, dtype=np.int32)
        for i in tqdm(range(0, x.size, params.frame_size), desc="Inference"):
            # The last window is shifted back so it is always a full frame;
            # it overlaps (and overwrites) the tail of the previous window.
            if i + params.frame_size > x.size:
                start, stop = x.size - params.frame_size, x.size
            else:
                start, stop = i, i + params.frame_size
            xx = x[start:stop]
            xx = xx.reshape(feat_shape)
            xx = augmenter(xx, training=True)
            runner.set_inputs(xx)
            runner.perform_inference()
            # Keep the processed signal so the plots show what the model saw.
            x[start:stop] = xx.numpy().squeeze()
            yy = runner.get_outputs()
            y_pred[start:stop] = np.argmax(yy, axis=-1).flatten()
        # END FOR
    finally:
        # Always release the backend, even if inference raised.
        runner.close()

    # Report
    logger.debug("Generating report")
    ts = np.arange(0, x.size) / params.sampling_rate

    # Peaks are located inside one specific predicted segment type; the two
    # signal types only differ in which segment, the minimum segment duration
    # and whether the extremum is taken on the rectified signal.
    if signal_type == "ECG":
        # Extract R peaks from QRS segments
        peak_segment, min_duration_ms, peak_signal = HKSegment.qrs, 20, np.abs(x)
    elif signal_type == "PPG":
        peak_segment, min_duration_ms, peak_signal = HKSegment.systolic, 100, x
    else:
        raise ValueError(f"Unknown signal type: {signal_type}")

    peak_label = params.class_map.get(peak_segment, -1)
    pred_bounds = np.concatenate(([0], np.diff(y_pred).nonzero()[0] + 1, [y_pred.size - 1]))
    peaks = []
    for start, stop in zip(pred_bounds[:-1], pred_bounds[1:]):
        duration = 1000 * (stop - start) / params.sampling_rate
        if y_pred[start] == peak_label and (duration > min_duration_ms):
            peaks.append(start + np.argmax(peak_signal[start:stop]))
        # END IF
    # END FOR
    peaks = np.array(peaks)

    band_names = ["VLF", "LF", "HF", "VHF"]
    bands = [(0.0033, 0.04), (0.04, 0.15), (0.15, 0.4), (0.4, 0.5)]

    # Compute R-R intervals
    rri = pk.ecg.compute_rr_intervals(peaks)
    mask = pk.ecg.filter_rr_intervals(rri, sample_rate=params.sampling_rate)
    rri_ms = 1000 * rri / params.sampling_rate
    # Compute metrics
    if (rri.size <= 2) or (mask.sum() / mask.size > 0.80):
        logger.warning("High percentage of RR intervals were filtered out")
        hr_bpm = 0
        hrv_td = pk.hrv.HrvTimeMetrics()
        hrv_fd = pk.hrv.HrvFrequencyMetrics(bands=[pk.hrv.HrvFrequencyBandMetrics() for _ in bands])
    else:
        hr_bpm = 60 / (np.nanmean(rri[mask == 0]) / params.sampling_rate)
        hrv_td = pk.hrv.compute_hrv_time(rri[mask == 0], sample_rate=params.sampling_rate)
        hrv_fd = pk.hrv.compute_hrv_frequency(
            peaks[mask == 0], rri[mask == 0], bands=bands, sample_rate=params.sampling_rate
        )
    # END IF

    # Flagged peaks are drawn red, the rest white. The length guard avoids an
    # IndexError should the mask ever be shorter than the peak list.
    peak_colors = ["red" if (i < mask.size and mask[i]) else "white" for i in range(len(peaks))]

    # Axis limits for the Poincare plot, shared by both reports. The fallback
    # keeps nanmin/nanmax from being called on an empty array.
    if rri_ms.size > 2:
        rr_min, rr_max = np.nanmin(rri_ms) - 20, np.nanmax(rri_ms) + 20
    else:
        rr_min, rr_max = 0, 2000

    normalized_band_power = np.array([b.total_power for b in hrv_fd.bands]) / hrv_fd.total_power

    # ------------------------------------------------------------------
    # Interactive plotly report
    # ------------------------------------------------------------------
    fig = make_subplots(
        rows=2,
        cols=3,
        specs=[
            [{"colspan": 3, "type": "xy", "secondary_y": True}, None, None],
            [{"type": "xy"}, {"type": "bar"}, {"type": "table"}],
        ],
        subplot_titles=(f"{signal_type} Plot", "IBI Poincare Plot", "HRV Frequency Bands"),
        horizontal_spacing=0.1,
        vertical_spacing=0.2,
    )

    fig.add_trace(
        go.Scatter(
            x=ts,
            y=x,
            name="SIGNAL",
            mode="lines",
            line=dict(color=plot_theme.primary_color, width=2),
        ),
        row=1,
        col=1,
        secondary_y=False,
    )

    for peak, color in zip(peaks, peak_colors):
        fig.add_vline(
            x=ts[peak],
            line_width=1,
            line_dash="dash",
            line_color=color,
            annotation={"text": "R-Peak", "textangle": -90, "font_color": color},
            row=1,
            col=1,
            secondary_y=False,
        )
    fig.update_xaxes(title_text="Time (s)", row=1, col=1)
    fig.update_yaxes(title_text=f"{signal_type}", row=1, col=1)

    # Overlay each predicted class on the signal; samples outside the class
    # are NaN so the line is only drawn where the class was predicted.
    for i, label in enumerate(classes):
        if label < 0:
            continue
        fig.add_trace(
            go.Scatter(
                x=ts,
                y=np.where(y_pred == label, x, np.nan),
                name=class_names[i],
                mode="lines",
                line_width=2,
            ),
            row=1,
            col=1,
            secondary_y=False,
        )
    # END FOR

    fig.add_trace(
        go.Bar(
            x=normalized_band_power,
            y=band_names,
            marker_color=[
                plot_theme.primary_color,
                plot_theme.secondary_color,
                plot_theme.tertiary_color,
                plot_theme.quaternary_color,
            ],
            orientation="h",
            showlegend=False,
        ),
        row=2,
        col=2,
    )
    fig.update_xaxes(title_text="Normalized Power", row=2, col=2)

    fig.add_trace(
        go.Scatter(
            x=rri_ms[:-1],
            y=rri_ms[1:],
            mode="markers",
            marker_size=10,
            showlegend=False,
            customdata=np.arange(1, rri_ms.size),
            hovertemplate="RRn: %{x:.1f} ms<br>RRn+1: %{y:.1f} ms<br>n: %{customdata}",
            marker_color=plot_theme.secondary_color,
        ),
        row=2,
        col=1,
        secondary_y=False,
    )
    # Identity line for the Poincare plot
    fig.add_shape(
        type="line",
        layer="below",
        y0=rr_min,
        y1=rr_max,
        x0=rr_min,
        x1=rr_max,
        line=dict(color="white", width=2),
        row=2,
        col=1,
    )
    fig.update_xaxes(title_text="RRn (ms)", range=[rr_min, rr_max], row=2, col=1)
    fig.update_yaxes(title_text="RRn+1 (ms)", range=[rr_min, rr_max], row=2, col=1)

    fig.add_trace(
        go.Table(
            header=dict(
                values=["Heart Rate <br> Variability Metric", "<br> Value"],
                font=dict(size=16, color="white"),
                height=60,
                fill_color=plot_theme.primary_color,
                align=["left"],
            ),
            cells=dict(
                values=[
                    ["Heart Rate", "NN Mean", "NN St. Dev", "SD RMS"],
                    [
                        f"{hr_bpm:.0f} BPM",
                        f"{hrv_td.mean_nn:.1f} ms",
                        f"{hrv_td.sd_nn:.1f} ms",
                        f"{hrv_td.rms_sd:.1f}",
                    ],
                ],
                font=dict(size=14),
                height=40,
                fill_color=plot_theme.bg_color,
                align=["left"],
            ),
        ),
        row=2,
        col=3,
    )

    fig.update_layout(
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        template=plot_theme.plotly_template,
        height=800,
        plot_bgcolor=plot_theme.bg_color,
        paper_bgcolor=plot_theme.bg_color,
        margin=dict(l=10, r=10, t=80, b=80),
        title="heartKIT: Segmentation Demo",
    )

    fig.write_html(params.job_dir / "demo.html", include_plotlyjs="cdn", full_html=False)
    logger.debug(f"Report saved to {params.job_dir / 'demo.html'}")

    if params.display_report:
        fig.show()

    # ------------------------------------------------------------------
    # Static matplotlib report: same layout as above
    # (first row is full width, bottom row has 3 columns)
    # ------------------------------------------------------------------
    fig = plt.figure(layout="constrained", figsize=(10, 6))
    gs = GridSpec(2, 3, figure=fig)
    ax1 = fig.add_subplot(gs[0, :])
    ax21 = fig.add_subplot(gs[1, 0])
    ax22 = fig.add_subplot(gs[1, 1])
    ax23 = fig.add_subplot(gs[1, 2])

    ax1.plot(ts, x, color=plot_theme.primary_color)

    for i, label in enumerate(classes):
        if label < 0:
            continue
        color = plot_theme.colors[i % len(plot_theme.colors)]
        ax1.plot(ts, np.where(y_pred == label, x, np.nan), label=class_names[i], color=color)
    # END FOR
    # Add R peaks (only the first line carries a legend label)
    for i, (peak, color) in enumerate(zip(peaks, peak_colors)):
        ax1.axvline(x=ts[peak], color=color, linestyle="--", label="R-Peak" if i == 0 else None)
    ax1.set_title(f"{signal_type} Plot")
    ax1.set_xlabel("Time (s)")
    ax1.set_ylabel(f"{signal_type}")
    # ax1.legend(loc="upper right", ncols=len(class_names))

    # IBI Poincare Plot
    ax21.scatter(rri_ms[:-1], rri_ms[1:], color=plot_theme.secondary_color)
    # Add 45-degree line, reusing the guarded limits computed for the plotly
    # report so an empty RR series does not raise here.
    ax21.plot([rr_min, rr_max], [rr_min, rr_max], color="white", linestyle="--")
    ax21.set_title("IBI Poincare Plot")
    ax21.set_xlabel("RRn (ms)")
    ax21.set_ylabel("RRn+1 (ms)")

    # HRV Frequency Bands
    ax22.barh(band_names, normalized_band_power)
    ax22.set_title("HRV Frequency Bands")
    ax22.set_xlabel("Normalized Power")

    # HRV Metrics
    ax23.axis("off")
    table_data = [
        ["Heart Rate", f"{hr_bpm:.0f} BPM"],
        ["NN Mean", f"{hrv_td.mean_nn:.1f} ms"],
        ["NN St. Dev", f"{hrv_td.sd_nn:.1f} ms"],
        ["SD RMS", f"{hrv_td.rms_sd:.1f}"],
    ]
    table = ax23.table(cellText=table_data, cellLoc="left", loc="center", cellColours=[[plot_theme.bg_color] * 2] * 4)
    table.scale(1.5, 3)
    table.auto_set_font_size(True)
    ax23.set_title("HRV Metrics")
    for cell in table.get_celld().values():
        cell.set_edgecolor(plot_theme.fg_color)

    fig.tight_layout()
    fig.savefig(params.job_dir / "demo.png")