Skip to content

demo

Classes

Functions

demo

demo(params: HKTaskParams)

Run demo for model

Parameters:

Source code in heartkit/tasks/rhythm/demo.py
def demo(params: HKTaskParams):
    """Run demo for model

    Args:
        params (HKTaskParams): Task parameters
    """
    logger = helia.utils.setup_logger(__name__, level=params.verbose)

    plot_theme = setup_plotting()

    params.demo_size = params.demo_size or 2 * params.frame_size

    # Load backend inference engine
    runner = BackendFactory.get(params.backend)(params=params)

    # Load data
    # classes = sorted(list(set(params.class_map.values())))
    class_names = params.class_names or [f"Class {i}" for i in range(params.num_classes)]

    feat_shape = (params.frame_size, 1)

    datasets = [DatasetFactory.get(ds.name)(**ds.params) for ds in params.datasets]
    ds = random.choice(datasets)

    ds_gen = ds.signal_generator(
        patient_generator=helia.utils.uniform_id_generator(ds.get_test_patient_ids(), repeat=False),
        frame_size=params.demo_size,
        samples_per_patient=5,
        target_rate=params.sampling_rate,
    )
    x = next(ds_gen)

    augmenter = create_augmentation_pipeline(
        params.preprocesses + params.augmentations, sampling_rate=params.sampling_rate
    )

    # Run inference
    runner.open()
    logger.debug("Running inference")
    y_preds: list[tuple[int, int, int, float]] = []  # (start, stop, class, prob)
    for i in tqdm(range(0, x.shape[0], params.frame_size), desc="Inference"):
        if i + params.frame_size > x.shape[0]:
            start, stop = x.shape[0] - params.frame_size, x.shape[0]
        else:
            start, stop = i, i + params.frame_size
        xx = x[start:stop]
        xx = xx.reshape(feat_shape)
        xx = augmenter(xx, training=False)
        runner.set_inputs(xx)
        runner.perform_inference()
        yy = runner.get_outputs()
        y_prob = np.exp(yy) / np.sum(np.exp(yy), axis=-1, keepdims=True)
        y_pred = np.argmax(y_prob, axis=-1)
        y_prob = y_prob[y_pred]
        if y_prob < params.threshold:
            y_pred = -1
        y_preds.append((start, stop - 1, y_pred, y_prob))
    # END FOR
    runner.close()

    # Report
    logger.debug("Generating report")
    tod = datetime.datetime(2025, 5, 24, random.randint(12, 23), 00)
    ts = np.array([tod + datetime.timedelta(seconds=i / params.sampling_rate) for i in range(x.shape[0])])

    fig = make_subplots(
        rows=1,
        cols=1,
        specs=[
            [{"colspan": 1, "type": "xy", "secondary_y": True}],
        ],
        subplot_titles=(None, None),
        horizontal_spacing=0.05,
        vertical_spacing=0.1,
    )

    fig.add_trace(
        go.Scatter(
            x=ts,
            y=x,
            name="ECG",
            mode="lines",
            line=dict(color=plot_theme.primary_color, width=2),
            showlegend=False,
        ),
        row=1,
        col=1,
        secondary_y=False,
    )

    for i, (start, stop, pred, prob) in enumerate(y_preds):
        if pred < 0:
            label = f"Inconclusive ({prob:0.0%})"
        else:
            label = f"{class_names[pred]} ({prob:0.0%})"
        color = plot_theme.colors[pred % len(plot_theme.colors)]
        if i > 0 and y_preds[i - 1][1] >= start:
            start = y_preds[i - 1][1] + 1
        fig.add_vrect(
            x0=ts[start],
            x1=ts[stop],
            annotation_text=label,
            fillcolor=color,
            opacity=0.25,
            line_width=2,
            line_color=color,
            row=1,
            col=1,
            secondary_y=False,
        )

    fig.update_layout(
        template=plot_theme.plotly_template,
        height=600,
        plot_bgcolor=plot_theme.bg_color,
        paper_bgcolor=plot_theme.bg_color,
        margin=dict(l=10, r=10, t=80, b=80),
        legend=dict(groupclick="toggleitem"),
        title="heartKIT: Rhythm Demo",
    )
    fig.write_html(params.job_dir / "demo.html", include_plotlyjs="cdn", full_html=False)
    logger.debug(f"Report saved to {params.job_dir / 'demo.html'}")

    # Matplotlib version

    # Make ts in seconds
    ts = np.arange(x.shape[0]) / params.sampling_rate
    fig, ax = plt.subplots(figsize=(9, 4))
    ax.plot(ts, x, color=plot_theme.primary_color, linewidth=2)
    for i, (start, stop, pred, prob) in enumerate(y_preds):
        if pred < 0:
            label = f"Inconclusive ({prob:0.0%})"
        else:
            label = f"{class_names[pred]} ({prob:0.0%})"
        color = plot_theme.colors[pred % len(plot_theme.colors)]
        ax.axvspan(ts[start], ts[stop], color=color, alpha=0.25, label=label)
    ax.set_title("heartKIT: Rhythm Demo")
    ax.set_xlabel("Time (sec)")
    ax.set_ylabel("ECG")
    ax.legend(loc="upper right", bbox_to_anchor=(1, 1), frameon=False)
    fig.tight_layout()
    fig.savefig(params.job_dir / "demo.png")

    if params.display_report:
        fig.show()

    # cleanup
    keras.utils.clear_session()
    for ds in datasets:
        ds.close()