Skip to content

demo

Classes

Functions

demo

demo(params: HKTaskParams)

Run translate demo for model

Parameters:

Source code in heartkit/tasks/translate/demo.py
def demo(params: HKTaskParams):
    """Run translate demo for model

    Args:
        params (HKTaskParams): Task parameters
    """
    bg_color = "rgba(38,42,50,1.0)"
    primary_color = "#11acd5"
    # secondary_color = "#ce6cff"
    tertiary_color = "rgb(234,52,36)"
    quaternary_color = "rgb(92,201,154)"
    plotly_template = "plotly_dark"

    params.demo_size = params.demo_size or 10 * params.sampling_rate

    # Load backend inference engine
    runner = BackendFactory.get(params.backend)(params=params)

    # feat_shape = (params.demo_size, 1)
    # class_shape = (params.demo_size, 1)

    # Load data
    datasets = [DatasetFactory.get(ds.name)(**ds.params) for ds in params.datasets]
    ds = random.choice(datasets)
    print(ds)

    # ds_gen = get_data_generator(
    #     ds, frame_size=params.demo_size, samples_per_patient=5, target_rate=params.sampling_rate
    # )

    # ds_gen = ds_gen(patient_generator=helia.utils.uniform_id_generator(ds.get_test_patient_ids(), repeat=False))

    # x, y = next(ds_gen)

    # x, y = prepare(
    #     (x, y),
    #     sample_rate=params.sampling_rate,
    #     preprocesses=params.preprocesses,
    #     augmentations=params.augmentations,
    #     spec=ds_spec,
    #     num_classes=params.num_classes,
    # )
    x, y = None, None
    x = x.flatten()
    y = y.flatten()

    # Run inference
    runner.open()
    logger.debug("Running inference")
    y_pred = np.zeros(x.size, dtype=np.float32)

    cos_sim_diff = 0
    prev_cos_sim = 0

    x_input = x.copy()
    for trial in range(8):
        for i in tqdm(range(0, x.size, params.frame_size), desc="Inference"):
            if i + params.frame_size > x.size:
                start, stop = x.size - params.frame_size, x.size
            else:
                start, stop = i, i + params.frame_size
            xx = x_input[start:stop]
            runner.set_inputs(xx)
            runner.perform_inference()
            yy = runner.get_outputs()
            y_pred[start:stop] = yy.flatten()
        # END FOR
        x_input = y_pred.copy()
        cos_sim = np.dot(y, y_pred) / (np.linalg.norm(y) * np.linalg.norm(y_pred))
        cos_sim_diff = cos_sim - prev_cos_sim
        prev_cos_sim = cos_sim
        logger.debug(f"Trial {trial + 1}: Cosine Similarity: {cos_sim:.2%} (diff: {cos_sim_diff:.2%})")
        if cos_sim_diff < 1e-3:
            break
    # END FOR

    # END FOR
    runner.close()
    # Report
    logger.debug("Generating report")
    ts = np.arange(0, x.size) / params.sampling_rate

    # Compute cosine similarity
    cos_sim_orig = np.dot(y, x) / (np.linalg.norm(y) * np.linalg.norm(x))
    cos_sim = np.dot(y, y_pred) / (np.linalg.norm(y) * np.linalg.norm(y_pred))
    logger.debug(f"Before Cosine Similarity: {cos_sim_orig:.2%}")
    logger.debug(f"After Cosine Similarity: {cos_sim:.2%}")

    fig = make_subplots(
        rows=3,
        cols=1,
        shared_xaxes=True,
        horizontal_spacing=0.1,
        vertical_spacing=0.05,
    )

    fig.add_trace(
        go.Scatter(
            x=ts,
            y=y,
            name="REF",
            mode="lines",
            line=dict(color=tertiary_color, width=3),
        ),
        row=1,
        col=1,
    )
    fig.update_yaxes(title_text="REF", row=1, col=1)

    fig.add_trace(
        go.Scatter(
            x=ts,
            y=x,
            name="INPUT",
            mode="lines",
            line=dict(color=primary_color, width=3),
        ),
        row=2,
        col=1,
    )
    fig.update_yaxes(title_text="INPUT", row=2, col=1)

    fig.add_trace(
        go.Scatter(
            x=ts,
            y=y_pred,
            name="OUTPUT",
            mode="lines",
            line=dict(color=quaternary_color, width=3),
        ),
        row=3,
        col=1,
    )
    fig.update_yaxes(title_text="OUTPUT", row=3, col=1)

    fig.add_annotation(
        x=1,
        y=0.3,
        text=f"CoSim: {cos_sim_orig:.2%}",
        showarrow=False,
        xref="paper",
        yref="paper",
        align="right",
        font=dict(
            family="Menlo",
            size=24,
        ),
    )

    fig.add_annotation(
        x=1,
        y=-0.08,
        text=f"CoSim: {cos_sim:.2%}",
        showarrow=False,
        xref="paper",
        yref="paper",
        align="right",
        font=dict(
            family="Menlo",
            size=24,
        ),
    )

    fig.update_xaxes(title_text="Time (s)", row=3, col=1)
    # fig.update_yaxes(title_text="SIGNAL", row=1, col=1)

    fig.update_layout(
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        template=plotly_template,
        height=800,
        plot_bgcolor=bg_color,
        paper_bgcolor=bg_color,
        margin=dict(l=10, r=10, t=80, b=60),
        title="heartKIT: Translate Demo",
    )

    fig.write_html(params.job_dir / "demo.html", include_plotlyjs="cdn", full_html=False)
    logger.debug(f"Report saved to {params.job_dir / 'demo.html'}")

    if params.display_report:
        fig.show()