Skip to content

demo

Classes

Functions

demo

demo(params: HKTaskParams)

Run demo for model

Parameters:

Source code in heartkit/tasks/foundation/demo.py
def demo(params: HKTaskParams):
    """Run demo for model

    Args:
        params (HKTaskParams): Task parameters
    """
    logger = helia.utils.setup_logger(__name__, level=params.verbose)
    plot_theme = setup_plotting()

    feat_shape = (params.frame_size, 1)
    num_pts = min(params.batch_size, 256)

    # Load backend inference engine
    runner = BackendFactory.get(params.backend)(params=params)

    # load datasets and randomly select one
    dataset = random.choice(params.datasets)
    ds = DatasetFactory.get(dataset.name)(**dataset.params)
    dataloader = FoundationTaskFactory.get(dataset.name)(
        ds=ds,
        frame_size=params.frame_size,
        sampling_rate=params.sampling_rate,
    )
    patients: npt.NDArray = ds.get_test_patient_ids()
    patients = np.random.choice(patients, size=num_pts, replace=False)

    x1 = np.zeros((num_pts, *feat_shape), dtype=np.float32)
    x2 = np.zeros((num_pts, *feat_shape), dtype=np.float32)

    # For each patient, generate TGT_LEN samples
    i = 0
    for patient in patients:
        for xx1, xx2 in dataloader.patient_data_generator(patient, 1):
            x1[i] = xx1
            x2[i] = xx2
            i += 1
        # END FOR
    # END FOR

    # Run inference
    runner.open()
    logger.debug("Running inference")
    y1 = np.zeros((num_pts, params.num_classes), dtype=np.float32)
    y2 = np.zeros((num_pts, params.num_classes), dtype=np.float32)
    for i in tqdm(range(0, num_pts), desc="Inference"):
        runner.set_inputs(x1[i])
        runner.perform_inference()
        y1[i] = runner.get_outputs()
        runner.set_inputs(x2[i])
        runner.perform_inference()
        y2[i] = runner.get_outputs()
    # END FOR
    runner.close()

    def cossim(x, y):
        return np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))

    # END DEF

    pts_cosim = np.zeros((num_pts, num_pts), dtype=np.float32)
    # Compute cosine similarity between pairs of patients
    for i in range(num_pts):
        for j in range(num_pts):
            pts_cosim[i, j] = cossim(y1[i], y1[j])
        # END FOR
    # END FOR

    # Get average cosine similarity down the diagonal
    mu_tgt_cos = np.mean(np.diag(pts_cosim))
    mu_other = np.mean(pts_cosim[np.eye(num_pts) == 0])

    logger.info(f"Average cosine similarity (target): {mu_tgt_cos:0.2%}")
    logger.info(f"Average cosine similarity (other): {mu_other:0.2%}")

    # Compute t-SNE
    tsne = TSNE(n_components=2, random_state=0, n_iter=5000, perplexity=75)
    x_tsne = tsne.fit_transform(np.concatenate([y1, y2]))

    fig = make_subplots(
        rows=1,
        cols=2,
        specs=[
            [{"type": "xy"}, {"type": "heatmap"}],
        ],
        subplot_titles=("t-SNE", "Cosine Similarity"),
    )

    pts = np.concatenate([patients, patients]).tolist()
    fig.add_trace(
        go.Scatter(
            x=x_tsne[:, 0],
            y=x_tsne[:, 1],
            mode="markers",
            customdata=pts,
            hovertemplate="Patient %{customdata}<extra></extra>",
            marker_color=[px.colors.qualitative.Dark24[i % 24] for i in pts],
            marker_size=8,
            name="t-SNE",
        ),
        row=1,
        col=1,
    )

    # Add heatmap of cosine similarity
    fig.add_trace(
        go.Heatmap(
            z=pts_cosim,
            x=list(range(num_pts)),
            y=list(range(num_pts)),
            colorscale="Plotly3",
            showscale=True,
            zmin=0,
            zmax=1,
        ),
        row=1,
        col=2,
    )

    fig.update_layout(
        template=plot_theme.plotly_template,
        height=600,
        plot_bgcolor=plot_theme.bg_color,
        paper_bgcolor=plot_theme.bg_color,
        margin=dict(l=10, r=10, t=80, b=80),
        legend=dict(groupclick="toggleitem"),
        title="heartKIT: Foundation Demo",
    )
    fig.write_html(params.job_dir / "demo.html", include_plotlyjs="cdn", full_html=True)
    logger.debug(f"Report saved to {params.job_dir / 'demo.html'}")

    if params.display_report:
        fig.show()