Skip to content

demo

Classes

Functions

demo

demo(params: HKTaskParams)

Run denoise demo.

Parameters:

params (HKTaskParams): Task parameters

Source code in heartkit/tasks/denoise/demo.py
def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Return the cosine similarity between two 1-D arrays.

    Returns 0.0 (instead of NaN) when either array has zero norm, e.g. an all-zero prediction,
    so that downstream comparisons and formatting stay well-defined.
    """
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0:
        return 0.0
    return float(np.dot(a, b) / denom)


def _run_inference_pass(runner, x_input: np.ndarray, frame_size: int) -> np.ndarray:
    """Run the backend over ``x_input`` in ``frame_size`` windows and stitch the outputs together.

    Windows start at 0, frame_size, 2*frame_size, ...; the final window is shifted back so it
    ends on the last sample (overlapping the previous one), so every sample is covered.
    Requires ``x_input.size >= frame_size``.
    """
    y_pred = np.zeros(x_input.size, dtype=np.float32)
    last_start = x_input.size - frame_size
    for i in tqdm(range(0, x_input.size, frame_size), desc="Inference"):
        # Clamp the tail window so it never runs past the end of the signal.
        start = min(i, last_start)
        stop = start + frame_size
        runner.set_inputs(x_input[start:stop])
        runner.perform_inference()
        y_pred[start:stop] = runner.get_outputs().flatten()
    return y_pred


def _build_plotly_report(plot_theme, ts, y_act, x, y_pred, cos_sim_orig, cos_sim):
    """Build the interactive 3-row (REF / NOISE / CLEAN) Plotly figure with cosine-similarity labels."""
    fig = make_subplots(
        rows=3,
        cols=1,
        shared_xaxes=True,
        horizontal_spacing=0.1,
        vertical_spacing=0.05,
    )

    traces = (
        (y_act, "REF", plot_theme.tertiary_color),
        (x, "NOISE", plot_theme.primary_color),
        (y_pred, "CLEAN", plot_theme.quaternary_color),
    )
    for row, (y, name, color) in enumerate(traces, start=1):
        fig.add_trace(
            go.Scatter(
                x=ts,
                y=y,
                name=name,
                mode="lines",
                line=dict(color=color, width=3),
            ),
            row=row,
            col=1,
        )
        fig.update_yaxes(title_text=name, row=row, col=1)

    # Paper-coordinate labels: similarity before denoising, then after denoising.
    for y_pos, value in ((0.3, cos_sim_orig), (-0.08, cos_sim)):
        fig.add_annotation(
            x=1,
            y=y_pos,
            text=f"CoSim: {value:.2%}",
            showarrow=False,
            xref="paper",
            yref="paper",
            align="right",
            font=dict(
                family="Menlo",
                size=24,
            ),
        )

    fig.update_xaxes(title_text="Time (s)", row=3, col=1)

    fig.update_layout(
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        template=plot_theme.plotly_template,
        height=800,
        plot_bgcolor=plot_theme.bg_color,
        paper_bgcolor=plot_theme.bg_color,
        margin=dict(l=10, r=10, t=80, b=60),
        title="heartKIT: Denoising Demo",
    )
    return fig


def _save_matplotlib_report(path, plot_theme, ts, y_act, x, y_pred, cos_sim_orig, cos_sim):
    """Save a static 3-row (REF / NOISE / CLEAN) PNG report to ``path`` and release the figure."""
    fig, ax = plt.subplots(3, 1, figsize=(10, 6), sharex=True)
    try:
        rows = (
            (y_act, "REF", plot_theme.tertiary_color),
            (x, "NOISE", plot_theme.primary_color),
            (y_pred, "CLEAN", plot_theme.quaternary_color),
        )
        for axis, (y, label, color) in zip(ax, rows):
            axis.plot(ts, y, color=color, linewidth=3)
            axis.set_ylabel(label)
        ax[2].set_xlabel("Time (s)")

        # Add annotations
        for axis, value in ((ax[1], cos_sim_orig), (ax[2], cos_sim)):
            axis.annotate(
                f"COS: {value:.0%}",
                xy=(0.99, 0.05),
                xycoords="axes fraction",
                xytext=(0, 0),
                textcoords="offset points",
                ha="right",
                va="bottom",
                fontsize=14,
                color=plot_theme.fg_color,
                fontweight="bold",
            )

        fig.tight_layout()
        fig.savefig(path)
    finally:
        # pyplot keeps every figure alive until closed; release it so repeated demos don't leak.
        plt.close(fig)


def demo(params: HKTaskParams):
    """Run denoise demo.

    Draws one test sample from a randomly chosen dataset and builds two signals from it.
    The reference is the sample after preprocessing. The noisy input is the sample after
    augmentation and then preprocessing. The backend is run over the noisy input
    frame-by-frame. Its output is fed back as the next input for up to ``max_trials``
    passes, stopping once cosine similarity to the reference improves by less than
    ``min_improvement``. The best-scoring pass is reported.

    Side effects: sets ``params.demo_size`` to ``10 * params.sampling_rate`` when it is
    falsy. Writes ``demo.html`` and ``demo.png`` into ``params.job_dir``. Shows the
    interactive figure when ``params.display_report`` is set.

    Args:
        params (HKTaskParams): Task parameters

    Raises:
        ValueError: If the demo signal is shorter than ``params.frame_size``.
    """
    logger = helia.utils.setup_logger(__name__, level=params.verbose)

    plot_theme = setup_plotting()

    params.demo_size = params.demo_size or 10 * params.sampling_rate

    # Load backend inference engine
    runner = BackendFactory.get(params.backend)(params=params)

    # Load data
    datasets = [DatasetFactory.get(ds.name)(cacheable=False, **ds.params) for ds in params.datasets]
    ds = random.choice(datasets)

    ds_gen = ds.signal_generator(
        patient_generator=helia.utils.uniform_id_generator(ds.get_test_patient_ids(), repeat=False),
        frame_size=params.demo_size,
        samples_per_patient=5,
        target_rate=params.sampling_rate,
    )
    x = next(ds_gen)
    x = np.nan_to_num(x, neginf=0, posinf=0).astype(np.float32)
    x = np.reshape(x, (-1, 1))
    y_act = x.copy()

    preprocessor = create_augmentation_pipeline(
        params.preprocesses,
        sampling_rate=params.sampling_rate,
    )
    augmenter = create_augmentation_pipeline(
        params.augmentations,
        sampling_rate=params.sampling_rate,
    )

    # Noisy input = augment then preprocess; reference = preprocess only.
    x = preprocessor(augmenter(x)).numpy().flatten()
    y_act = preprocessor(y_act).numpy().flatten()

    if x.size < params.frame_size:
        # Otherwise the clamped tail window would start at a negative index and mis-slice.
        raise ValueError(
            f"Demo signal has {x.size} samples but frame_size is {params.frame_size}; "
            "increase demo_size or decrease frame_size."
        )

    # Run inference
    max_trials = 8
    min_improvement = 1e-3
    best_pred = None
    best_cos_sim = 0.0
    prev_cos_sim = 0.0
    x_input = x

    runner.open()
    try:
        logger.debug("Running inference")
        for trial in range(max_trials):
            y_pred = _run_inference_pass(runner, x_input, params.frame_size)
            cos_sim = _cosine_similarity(y_act, y_pred)
            cos_sim_diff = cos_sim - prev_cos_sim
            logger.debug(f"Trial {trial + 1}: Cosine Similarity: {cos_sim:.2%} (diff: {cos_sim_diff:.2%})")
            # Keep the best pass so a degrading final refinement is not what gets reported.
            if best_pred is None or cos_sim > best_cos_sim:
                best_pred, best_cos_sim = y_pred, cos_sim
            if cos_sim_diff < min_improvement:
                break
            prev_cos_sim = cos_sim
            # Feed this pass's output back in as the next pass's input.
            x_input = y_pred
    finally:
        runner.close()
    y_pred = best_pred

    # Report
    logger.debug("Generating report")
    ts = np.arange(0, x.size) / params.sampling_rate

    # Compute cosine similarity
    cos_sim_orig = _cosine_similarity(y_act, x)
    cos_sim = _cosine_similarity(y_act, y_pred)
    logger.debug(f"Before Cosine Similarity: {cos_sim_orig:.2%}")
    logger.debug(f"After Cosine Similarity: {cos_sim:.2%}")

    fig = _build_plotly_report(plot_theme, ts, y_act, x, y_pred, cos_sim_orig, cos_sim)
    fig.write_html(params.job_dir / "demo.html", include_plotlyjs="cdn", full_html=False)
    logger.debug(f"Report saved to {params.job_dir / 'demo.html'}")

    if params.display_report:
        fig.show()

    _save_matplotlib_report(params.job_dir / "demo.png", plot_theme, ts, y_act, x, y_pred, cos_sim_orig, cos_sim)