Skip to content

demo

Classes

Functions

demo

demo(params: HKTaskParams)

Run diagnostic demo.

Parameters:

- params (HKTaskParams): Task parameters.

Source code in heartkit/tasks/diagnostic/demo.py
def demo(params: HKTaskParams):
    """Run diagnostic demo.

    Picks one of the configured datasets at random, pulls a single signal from
    its test patients, runs the backend inference engine over the signal one
    frame at a time, and renders the signal with the predicted (non-zero)
    class regions highlighted in an interactive Plotly report.

    Side effects:
        * ``params.demo_size`` is filled in with ``2 * params.frame_size``
          when it is unset/falsy.
        * ``demo.html`` is written to ``params.job_dir``.
        * The figure is opened when ``params.display_report`` is truthy.

    Args:
        params (HKTaskParams): Task parameters

    Raises:
        ValueError: If the generated signal is shorter than one frame
            (``params.frame_size``), so no full inference window exists.
    """

    bg_color = "rgba(38,42,50,1.0)"
    primary_color = "#11acd5"
    secondary_color = "#ce6cff"
    plotly_template = "plotly_dark"

    params.demo_size = params.demo_size or 2 * params.frame_size

    # Load backend inference engine
    runner = BackendFactory.get(params.backend)(params=params)

    class_names = params.class_names or [f"Class {i}" for i in range(params.num_classes)]

    # Shape every frame is reshaped to before being handed to the backend.
    feat_shape = (params.frame_size, 1)

    # Load data
    datasets = [DatasetFactory.get(ds.name)(**ds.params) for ds in params.datasets]
    ds = random.choice(datasets)

    ds_gen = ds.signal_generator(
        patient_generator=helia.utils.uniform_id_generator(ds.get_test_patient_ids(), repeat=False),
        frame_size=params.demo_size,
        samples_per_patient=5,
        target_rate=params.sampling_rate,
    )
    # Only the first generated signal is used for the demo.
    # NOTE(review): assumes the generator yields a single-channel array with
    # samples along axis 0 -- confirm against the dataset implementation.
    x = next(ds_gen)

    num_samples = x.shape[0]
    # Validate before opening the backend: the tail-window logic below would
    # otherwise compute a negative start index and fail with an opaque error.
    if num_samples < params.frame_size:
        raise ValueError(
            f"Demo signal has {num_samples} samples but at least one full frame "
            f"of {params.frame_size} samples is required"
        )

    # Run inference
    runner.open()
    try:
        logger.debug("Running inference")
        y_pred = np.zeros(num_samples, dtype=np.int32)
        for i in tqdm(range(0, num_samples, params.frame_size), desc="Inference"):
            if i + params.frame_size > num_samples:
                # Final partial window: slide back so a full frame is fed to
                # the model; it overlaps (and overwrites) the previous frame.
                start, stop = num_samples - params.frame_size, num_samples
            else:
                start, stop = i, i + params.frame_size
            # NOTE: frames are passed to the backend as-is; no preprocessing
            # is applied here.
            xx = x[start:stop].reshape(feat_shape)
            runner.set_inputs(xx)
            runner.perform_inference()
            yy = runner.get_outputs()
            y_pred[start:stop] = np.argmax(yy, axis=-1).flatten()
        # END FOR
    finally:
        # Always release the backend, even if inference fails part-way.
        runner.close()

    # Report
    logger.debug("Generating report")
    # Synthetic wall-clock timestamps: fixed date, random start hour.
    tod = datetime.datetime(2025, 5, 24, random.randint(12, 23), 0)
    ts = np.array([tod + datetime.timedelta(seconds=i / params.sampling_rate) for i in range(num_samples)])

    # Indices where the predicted class changes, bracketed by the first and
    # last sample index so consecutive pairs delimit constant-class segments.
    pred_bounds = np.concatenate(([0], np.diff(y_pred).nonzero()[0] + 1, [y_pred.size - 1]))

    fig = make_subplots(
        rows=1,
        cols=1,
        specs=[
            [{"colspan": 1, "type": "xy", "secondary_y": True}],
        ],
        subplot_titles=(None, None),
        horizontal_spacing=0.05,
        vertical_spacing=0.1,
    )

    fig.add_trace(
        go.Scatter(
            x=ts,
            y=x,
            name="ECG",
            mode="lines",
            line=dict(color=primary_color, width=2),
            showlegend=False,
        ),
        row=1,
        col=1,
        secondary_y=False,
    )

    for i in range(1, len(pred_bounds)):
        start, stop = pred_bounds[i - 1], pred_bounds[i]
        pred_class = y_pred[start]
        # Segments predicted as class 0 are left unhighlighted.
        if pred_class <= 0:
            continue
        fig.add_vrect(
            x0=ts[start],
            x1=ts[stop],
            annotation_text=class_names[pred_class],
            fillcolor=secondary_color,
            opacity=0.25,
            line_width=2,
            line_color=secondary_color,
            row=1,
            col=1,
            secondary_y=False,
        )

    fig.update_layout(
        template=plotly_template,
        height=600,
        plot_bgcolor=bg_color,
        paper_bgcolor=bg_color,
        margin=dict(l=10, r=10, t=80, b=80),
        legend=dict(groupclick="toggleitem"),
        title="heartKIT: Diagnostic Demo",
    )

    report_path = params.job_dir / "demo.html"
    fig.write_html(report_path, include_plotlyjs="cdn", full_html=True)
    logger.debug(f"Report saved to {report_path}")

    if params.display_report:
        fig.show()