HeartKit

heartkit.cli

parse_content(cls, content)

Parse file or raw content into Pydantic model.

Parameters:

  • cls (B) –

    Pydantic model subclass

  • content (str) –

    File path or raw content

Returns:

  • B ( B ) –

    Pydantic model subclass instance

Source code in heartkit/cli.py
def parse_content(cls: Type[B], content: str) -> B:
    """Parse file or raw content into Pydantic model.

    Args:
        cls (B): Pydantic model subclass
        content (str): File path or raw content

    Returns:
        B: Pydantic model subclass instance
    """
    if os.path.isfile(content):
        with open(content, "r", encoding="utf-8") as f:
            content = f.read()

    return cls.model_validate_json(json_data=content)
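
A minimal usage sketch (model and values are illustrative): content may be either raw JSON or a path to a JSON file, and any Pydantic model from heartkit.defines works.

from heartkit.cli import parse_content
from heartkit.defines import DatasetParams

# Raw JSON content; a path to a .json file would be read and parsed the same way
params = parse_content(DatasetParams, '{"name": "ludb", "weight": 0.5}')
print(params.name, params.weight)  # ludb 0.5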

run()

Run CLI.

Source code in heartkit/cli.py
def run():
    """Run CLI."""
    cli()

heartkit.defines

AugmentationParams

Bases: BaseModel

Augmentation parameters

Source code in heartkit/defines.py
class AugmentationParams(BaseModel, extra="allow"):
    """Augmentation parameters"""

    name: str
    params: dict[str, tuple[float | int, float | int]]

DatasetParams

Bases: BaseModel

Dataset parameters

Source code in heartkit/defines.py
class DatasetParams(BaseModel, extra="allow"):
    """Dataset parameters"""

    name: str
    params: dict[str, Any] = Field(default_factory=dict, description="Parameters")
    weight: float = Field(1, description="Dataset weight")

HKDemoParams

Bases: BaseModel

HK demo command params

Source code in heartkit/defines.py
class HKDemoParams(BaseModel, extra="allow"):
    """HK demo command params"""

    job_dir: Path = Field(default_factory=lambda: Path(tempfile.gettempdir()), description="Job output directory")
    # Dataset arguments
    ds_path: Path = Field(default_factory=lambda: Path("./datasets"), description="Dataset directory")
    datasets: list[DatasetParams] = Field(default_factory=list, description="Datasets")
    sampling_rate: int = Field(250, description="Target sampling rate (Hz)")
    frame_size: int = Field(1250, description="Frame size")
    num_classes: int = Field(1, description="# of classes")
    class_map: dict[int, int] = Field(default_factory=lambda: {1: 1}, description="Class/label mapping")
    class_names: list[str] | None = Field(default=None, description="Class names")
    preprocesses: list[PreprocessParams] = Field(default_factory=list, description="Preprocesses")
    augmentations: list[AugmentationParams] = Field(default_factory=list, description="Augmentations")
    # Model arguments
    model_file: Path | None = Field(None, description="Path to save model file (.keras)")
    backend: str = Field("pc", description="Backend")
    demo_size: int | None = Field(1000, description="# samples for demo")
    display_report: bool = Field(True, description="Display report")
    # Extra arguments
    seed: int | None = Field(None, description="Random state seed")
    model_config = ConfigDict(protected_namespaces=())

    def model_post_init(self, __context: Any) -> None:
        """Post init hook"""

        if self.model_file and len(self.model_file.parts) == 1:
            self.model_file = self.job_dir / self.model_file

model_post_init(__context)

Post init hook

Source code in heartkit/defines.py
def model_post_init(self, __context: Any) -> None:
    """Post init hook"""

    if self.model_file and len(self.model_file.parts) == 1:
        self.model_file = self.job_dir / self.model_file

HKDownloadParams

Bases: BaseModel

Download command params

Source code in heartkit/defines.py
class HKDownloadParams(BaseModel, extra="allow"):
    """Download command params"""

    job_dir: Path = Field(default_factory=lambda: Path(tempfile.gettempdir()), description="Job output directory")
    ds_path: Path = Field(default_factory=Path, description="Dataset root directory")
    datasets: list[str] = Field(default_factory=list, description="Datasets")
    progress: bool = Field(True, description="Display progress bar")
    force: bool = Field(False, description="Force download dataset, overriding existing files")
    data_parallelism: int = Field(
        default_factory=lambda: os.cpu_count() or 1,
        description="# of data loaders running in parallel",
    )

HKExportParams

Bases: BaseModel

Export command params

Source code in heartkit/defines.py
class HKExportParams(BaseModel, extra="allow"):
    """Export command params"""

    job_dir: Path = Field(default_factory=lambda: Path(tempfile.gettempdir()), description="Job output directory")
    # Dataset arguments
    ds_path: Path = Field(default_factory=lambda: Path("./datasets"), description="Dataset directory")
    datasets: list[DatasetParams] = Field(default_factory=list, description="Datasets")
    sampling_rate: int = Field(250, description="Target sampling rate (Hz)")
    frame_size: int = Field(1250, description="Frame size")
    num_classes: int = Field(3, description="# of classes")
    class_map: dict[int, int] = Field(default_factory=lambda: {1: 1}, description="Class/label mapping")
    class_names: list[str] | None = Field(default=None, description="Class names")
    test_samples_per_patient: int | list[int] = Field(100, description="# test samples per patient")
    test_patients: float | None = Field(None, description="# or proportion of patients for testing")
    test_size: int = Field(100_000, description="# samples for testing")
    test_file: Path | None = Field(None, description="Path to load/store pickled test file")
    preprocesses: list[PreprocessParams] = Field(default_factory=list, description="Preprocesses")
    augmentations: list[AugmentationParams] = Field(default_factory=list, description="Augmentations")
    model_file: Path | None = Field(None, description="Path to save model file (.keras)")
    threshold: float | None = Field(None, description="Model output threshold")
    val_acc_threshold: float | None = Field(0.98, description="Validation accuracy threshold")
    use_logits: bool = Field(True, description="Use logits output or softmax")
    quantization: QuantizationParams = Field(default_factory=QuantizationParams, description="Quantization parameters")
    tflm_var_name: str = Field("g_model", description="TFLite Micro C variable name")
    tflm_file: Path | None = Field(None, description="Path to copy TFLM header file (e.g. ./model_buffer.h)")
    data_parallelism: int = Field(
        default_factory=lambda: os.cpu_count() or 1,
        description="# of data loaders running in parallel",
    )
    model_config = ConfigDict(protected_namespaces=())

    def model_post_init(self, __context: Any) -> None:
        """Post init hook"""

        if self.test_file and len(self.test_file.parts) == 1:
            self.test_file = self.job_dir / self.test_file

        if self.model_file and len(self.model_file.parts) == 1:
            self.model_file = self.job_dir / self.model_file

        if self.tflm_file and len(self.tflm_file.parts) == 1:
            self.tflm_file = self.job_dir / self.tflm_file

model_post_init(__context)

Post init hook

Source code in heartkit/defines.py
def model_post_init(self, __context: Any) -> None:
    """Post init hook"""

    if self.test_file and len(self.test_file.parts) == 1:
        self.test_file = self.job_dir / self.test_file

    if self.model_file and len(self.model_file.parts) == 1:
        self.model_file = self.job_dir / self.model_file

    if self.tflm_file and len(self.tflm_file.parts) == 1:
        self.tflm_file = self.job_dir / self.tflm_file

HKMode

Bases: StrEnum

HeartKit Mode

Source code in heartkit/defines.py
class HKMode(StrEnum):
    """HeartKit Mode"""

    download = "download"
    train = "train"
    evaluate = "evaluate"
    export = "export"
    demo = "demo"

HKTestParams

Bases: BaseModel

Test command params

Source code in heartkit/defines.py
class HKTestParams(BaseModel, extra="allow"):
    """Test command params"""

    job_dir: Path = Field(default_factory=lambda: Path(tempfile.gettempdir()), description="Job output directory")
    # Dataset arguments
    ds_path: Path = Field(default_factory=lambda: Path("./datasets"), description="Dataset directory")
    datasets: list[DatasetParams] = Field(default_factory=list, description="Datasets")
    sampling_rate: int = Field(250, description="Target sampling rate (Hz)")
    frame_size: int = Field(1250, description="Frame size")
    num_classes: int = Field(1, description="# of classes")
    class_map: dict[int, int] = Field(default_factory=lambda: {1: 1}, description="Class/label mapping")
    class_names: list[str] | None = Field(default=None, description="Class names")
    test_samples_per_patient: int | list[int] = Field(1000, description="# test samples per patient")
    test_patients: float | None = Field(None, description="# or proportion of patients for testing")
    test_size: int = Field(200_000, description="# samples for testing")
    test_file: Path | None = Field(None, description="Path to load/store pickled test file")
    preprocesses: list[PreprocessParams] = Field(default_factory=list, description="Preprocesses")
    augmentations: list[AugmentationParams] = Field(default_factory=list, description="Augmentations")
    # Model arguments
    model_file: Path | None = Field(None, description="Path to save model file (.keras)")
    threshold: float | None = Field(None, description="Model output threshold")
    # Extra arguments
    seed: int | None = Field(None, description="Random state seed")
    data_parallelism: int = Field(
        default_factory=lambda: os.cpu_count() or 1,
        description="# of data loaders running in parallel",
    )
    model_config = ConfigDict(protected_namespaces=())

    def model_post_init(self, __context: Any) -> None:
        """Post init hook"""

        if self.test_file and len(self.test_file.parts) == 1:
            self.test_file = self.job_dir / self.test_file

        if self.model_file and len(self.model_file.parts) == 1:
            self.model_file = self.job_dir / self.model_file

model_post_init(__context)

Post init hook

Source code in heartkit/defines.py
def model_post_init(self, __context: Any) -> None:
    """Post init hook"""

    if self.test_file and len(self.test_file.parts) == 1:
        self.test_file = self.job_dir / self.test_file

    if self.model_file and len(self.model_file.parts) == 1:
        self.model_file = self.job_dir / self.model_file

HKTrainParams

Bases: BaseModel

Train command params

Source code in heartkit/defines.py
class HKTrainParams(BaseModel, extra="allow"):
    """Train command params"""

    name: str = Field("experiment", description="Experiment name")
    job_dir: Path = Field(default_factory=lambda: Path(tempfile.gettempdir()), description="Job output directory")
    # Dataset arguments
    ds_path: Path = Field(default_factory=lambda: Path("./datasets"), description="Dataset directory")
    datasets: list[DatasetParams] = Field(default_factory=list, description="Datasets")
    sampling_rate: int = Field(250, description="Target sampling rate (Hz)")
    frame_size: int = Field(1250, description="Frame size")
    num_classes: int = Field(1, description="# of classes")
    class_map: dict[int, int] = Field(default_factory=lambda: {1: 1}, description="Class/label mapping")
    class_names: list[str] | None = Field(default=None, description="Class names")
    samples_per_patient: int | list[int] = Field(1000, description="# train samples per patient")
    val_samples_per_patient: int | list[int] = Field(1000, description="# validation samples per patient")
    train_patients: float | None = Field(None, description="# or proportion of patients for training")
    val_patients: float | None = Field(None, description="# or proportion of patients for validation")
    val_file: Path | None = Field(None, description="Path to load/store pickled validation file")
    val_size: int | None = Field(None, description="# samples for validation")
    # Model arguments
    resume: bool = Field(False, description="Resume training")
    architecture: ModelArchitecture | None = Field(default=None, description="Custom model architecture")
    model_file: Path | None = Field(None, description="Path to save model file (.keras)")
    threshold: float | None = Field(None, description="Model output threshold")

    weights_file: Path | None = Field(None, description="Path to a checkpoint weights to load")
    quantization: QuantizationParams = Field(default_factory=QuantizationParams, description="Quantization parameters")
    # Training arguments
    lr_rate: float = Field(1e-3, description="Learning rate")
    lr_cycles: int = Field(3, description="Number of learning rate cycles")
    lr_decay: float = Field(0.9, description="Learning rate decay")
    class_weights: Literal["balanced", "fixed"] = Field("fixed", description="Class weights")
    label_smoothing: float = Field(0, description="Label smoothing")
    batch_size: int = Field(32, description="Batch size")
    buffer_size: int = Field(100, description="Buffer size")
    epochs: int = Field(50, description="Number of epochs")
    steps_per_epoch: int = Field(10, description="Number of steps per epoch")
    val_metric: Literal["loss", "acc", "f1"] = Field("loss", description="Performance metric")
    # Preprocessing/Augmentation arguments
    preprocesses: list[PreprocessParams] = Field(default_factory=list, description="Preprocesses")
    augmentations: list[AugmentationParams] = Field(default_factory=list, description="Augmentations")
    # Extra arguments
    seed: int | None = Field(None, description="Random state seed")
    data_parallelism: int = Field(
        default_factory=lambda: os.cpu_count() or 1,
        description="# of data loaders running in parallel",
    )
    model_config = ConfigDict(protected_namespaces=())

    def model_post_init(self, __context: Any) -> None:
        """Post init hook"""

        if self.val_file and len(self.val_file.parts) == 1:
            self.val_file = self.job_dir / self.val_file

        if self.model_file and len(self.model_file.parts) == 1:
            self.model_file = self.job_dir / self.model_file

        if self.weights_file and len(self.weights_file.parts) == 1:
            self.weights_file = self.job_dir / self.weights_file

model_post_init(__context)

Post init hook

Source code in heartkit/defines.py
def model_post_init(self, __context: Any) -> None:
    """Post init hook"""

    if self.val_file and len(self.val_file.parts) == 1:
        self.val_file = self.job_dir / self.val_file

    if self.model_file and len(self.model_file.parts) == 1:
        self.model_file = self.job_dir / self.model_file

    if self.weights_file and len(self.weights_file.parts) == 1:
        self.weights_file = self.job_dir / self.weights_file
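
Since model_post_init re-roots single-part file names under job_dir, a configuration can reference outputs by bare name. A minimal sketch (field values are illustrative) using parse_content from heartkit.cli:

from heartkit.cli import parse_content
from heartkit.defines import HKTrainParams

config = '{"job_dir": "./results/train", "model_file": "model.keras", "epochs": 10, "batch_size": 64}'
params = parse_content(HKTrainParams, config)
# "model.keras" has a single path part, so it is re-rooted under job_dir
print(params.model_file)  # results/train/model.keras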

ModelArchitecture

Bases: BaseModel

Model architecture parameters

Source code in heartkit/defines.py
class ModelArchitecture(BaseModel, extra="allow"):
    """Model architecture parameters"""

    name: str
    params: dict[str, Any] = Field(default_factory=dict, description="Parameters")

PreprocessParams

Bases: BaseModel

Preprocessing parameters

Source code in heartkit/defines.py
class PreprocessParams(BaseModel, extra="allow"):
    """Preprocessing parameters"""

    name: str
    params: dict[str, Any]

QuantizationParams

Bases: BaseModel

Quantization parameters

Source code in heartkit/defines.py
class QuantizationParams(BaseModel, extra="allow"):
    """Quantization parameters"""

    enabled: bool = Field(False, description="Enable quantization")
    qat: bool = Field(False, description="Enable quantization aware training (QAT)")
    ptq: bool = Field(False, description="Enable post training quantization (PTQ)")
    input_type: str | None = Field(None, description="Input type")
    output_type: str | None = Field(None, description="Output type")
    supported_ops: list[str] | None = Field(None, description="Supported ops")

heartkit.metrics

compute_iou(y_true, y_pred, average='micro')

Compute IoU

Parameters:

  • y_true (NDArray) –

    Y true

  • y_pred (NDArray) –

    Y predicted

  • average (Literal['micro', 'macro', 'weighted'], default: 'micro' ) –

    Averaging method. Defaults to "micro".

Returns:

  • float ( float ) –

    IoU

Source code in heartkit/metrics.py
def compute_iou(
    y_true: npt.NDArray,
    y_pred: npt.NDArray,
    average: Literal["micro", "macro", "weighted"] = "micro",
) -> float:
    """Compute IoU

    Args:
        y_true (npt.NDArray): Y true
        y_pred (npt.NDArray): Y predicted
        average (Literal["micro", "macro", "weighted"], optional): Averaging method. Defaults to "micro".

    Returns:
        float: IoU
    """
    return jaccard_score(y_true.flatten(), y_pred.flatten(), average=average)
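
A quick sketch with toy arrays: both inputs are flattened internally, so segmentation masks of any shape can be compared directly.

import numpy as np
from heartkit.metrics import compute_iou

y_true = np.array([[0, 1, 1, 0], [1, 1, 0, 0]])  # e.g. per-sample segmentation masks
y_pred = np.array([[0, 1, 0, 0], [1, 1, 0, 1]])
print(compute_iou(y_true, y_pred, average="micro"))  # 0.6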

confusion_matrix_plot(y_true, y_pred, labels, save_path=None, normalize=False, **kwargs)

Generate confusion matrix plot via matplotlib/seaborn

Parameters:

  • y_true (NDArray) –

    True y labels

  • y_pred (NDArray) –

    Predicted y labels

  • labels (list[str]) –

    Label names

  • save_path (str | None, default: None ) –

    Path to save plot. Defaults to None.

  • normalize (Literal['true', 'pred', 'all'] | None, default: None ) –

    Normalize counts. Defaults to None.

Returns:

  • tuple[Figure, Axes] | None

    tuple[plt.Figure, plt.Axes] | None: Figure and axes

Source code in heartkit/metrics.py
def confusion_matrix_plot(
    y_true: npt.NDArray,
    y_pred: npt.NDArray,
    labels: list[str],
    save_path: os.PathLike | None = None,
    normalize: Literal["true", "pred", "all"] | None = None,
    **kwargs,
) -> tuple[plt.Figure, plt.Axes] | None:
    """Generate confusion matrix plot via matplotlib/seaborn

    Args:
        y_true (npt.NDArray): True y labels
        y_pred (npt.NDArray): Predicted y labels
        labels (list[str]): Label names
        save_path (str | None): Path to save plot. Defaults to None.
        normalize (Literal["true", "pred", "all"] | None, optional): Normalize counts. Defaults to None.

    Returns:
        tuple[plt.Figure, plt.Axes] | None: Figure and axes
    """

    cm = confusion_matrix(y_true, y_pred)
    cmn = cm
    ann = True
    fmt = "g"
    if normalize:
        cmn = confusion_matrix(y_true, y_pred, normalize=normalize)
        ann = np.asarray([f"{c:g}{os.linesep}{nc:.2%}" for c, nc in zip(cm.flatten(), cmn.flatten())]).reshape(cm.shape)
        fmt = ""
    # END IF
    fig, ax = plt.subplots(figsize=kwargs.get("figsize", (10, 8)))
    sns.heatmap(cmn, xticklabels=labels, yticklabels=labels, annot=ann, fmt=fmt, ax=ax)
    ax.set_xlabel("Prediction")
    ax.set_ylabel("Label")
    if save_path:
        fig.savefig(save_path, bbox_inches="tight")
        plt.close(fig)
        return None
    # END IF
    return fig, ax
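
A usage sketch with illustrative label names: when save_path is given the figure is written to disk and None is returned; otherwise the figure and axes come back for further styling.

import numpy as np
from heartkit.metrics import confusion_matrix_plot

y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0])
fig, ax = confusion_matrix_plot(y_true, y_pred, labels=["normal", "afib", "noise"], normalize="true")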

f1(y_true, y_prob, multiclass=False, threshold=None)

Compute F1 scores

Parameters:

  • y_true (NDArray) –

    Y true

  • y_prob (NDArray) –

    2D matrix with class probs

  • multiclass (bool, default: False ) –

    If multiclass. Defaults to False.

  • threshold (float | None, default: None ) –

    Decision threshold for multiclass. Defaults to None.

Returns:

  • NDArray | float –

    F1 scores

Source code in heartkit/metrics.py
def f1(
    y_true: npt.NDArray,
    y_prob: npt.NDArray,
    multiclass: bool = False,
    threshold: float | None = None,
) -> npt.NDArray | float:
    """Compute F1 scores

    Args:
        y_true ( npt.NDArray): Y true
        y_prob ( npt.NDArray): 2D matrix with class probs
        multiclass (bool, optional): If multiclass. Defaults to False.
        threshold (float | None, optional): Decision threshold for multiclass. Defaults to None.

    Returns:
        npt.NDArray|float: F1 scores
    """
    if y_prob.ndim != 2:
        raise ValueError("y_prob must be a 2d matrix with class probabilities for each sample")
    if y_true.ndim == 1:  # we assume that y_true is sparse (consequently, multiclass=False)
        if multiclass:
            raise ValueError("if y_true cannot be sparse and multiclass at the same time")
        depth = y_prob.shape[1]
        y_true = _one_hot(y_true, depth)
    if multiclass:
        if threshold is None:
            threshold = 0.5
        y_pred = y_prob >= threshold
    else:
        y_pred = y_prob >= np.max(y_prob, axis=1)[:, None]
    return f1_score(y_true, y_pred, average="macro")
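
A toy sketch of the single-label path: a sparse y_true is one-hot encoded internally, and predictions are taken as the per-row argmax of y_prob.

import numpy as np
from heartkit.metrics import f1

y_true = np.array([0, 1, 2, 1])  # sparse labels
y_prob = np.array([
    [0.8, 0.1, 0.1],
    [0.2, 0.7, 0.1],
    [0.1, 0.2, 0.7],
    [0.6, 0.3, 0.1],  # misclassified as class 0
])
print(f1(y_true, y_prob))  # macro F1 over the three classes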

f_max(y_true, y_prob, thresholds=None)

Compute F max. Source: https://github.com/helme/ecg_ptbxl_benchmarking

Parameters:

  • y_true (NDArray) –

    Y True

  • y_prob (NDArray) –

    Y probs

  • thresholds (float | list[float] | None, default: None ) –

    Thresholds. Defaults to None.

Returns:

  • tuple[float, float] –

    Max F1 and corresponding threshold

Source code in heartkit/metrics.py
def f_max(
    y_true: npt.NDArray,
    y_prob: npt.NDArray,
    thresholds: float | list[float] | None = None,
) -> tuple[float, float]:
    """Compute F max
    source: https://github.com/helme/ecg_ptbxl_benchmarking

    Args:
        y_true (npt.NDArray): Y True
        y_prob (npt.NDArray): Y probs
        thresholds (float|list[float]|None, optional): Thresholds. Defaults to None.

    Returns:
        tuple[float, float]: Max F1 and corresponding threshold
    """
    if thresholds is None:
        thresholds = np.linspace(0, 1, 100)
    pr, rc = macro_precision_recall(y_true, y_prob, thresholds)
    f1s = (2 * pr * rc) / (pr + rc)
    i = np.nanargmax(f1s)
    return f1s[i], thresholds[i]
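
A multilabel sketch: with thresholds=None, 100 evenly spaced thresholds over [0, 1] are swept and the one maximizing macro F1 is returned alongside the score.

import numpy as np
from heartkit.metrics import f_max

y_true = np.array([[1, 0, 1], [0, 1, 0], [1, 1, 0]])  # 3 samples x 3 labels
y_prob = np.array([[0.9, 0.2, 0.6], [0.1, 0.8, 0.3], [0.7, 0.6, 0.2]])
best_f1, best_thr = f_max(y_true, y_prob)
print(f"F-max {best_f1:.3f} @ threshold {best_thr:.2f}")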

macro_precision_recall(y_true, y_prob, thresholds)

Compute macro precision and recall. Source: https://github.com/helme/ecg_ptbxl_benchmarking

Parameters:

  • y_true (NDArray) –

    True y labels

  • y_prob (NDArray) –

    Predicted probabilities

  • thresholds (NDArray) –

    Thresholds

Returns:

  • tuple[NDArray, NDArray] –

    Macro precision and recall per threshold

Source code in heartkit/metrics.py
def macro_precision_recall(
    y_true: npt.NDArray, y_prob: npt.NDArray, thresholds: npt.NDArray
) -> tuple[npt.NDArray, npt.NDArray]:
    """Compute macro precision and recall
    source: https://github.com/helme/ecg_ptbxl_benchmarking

    Args:
        y_true (npt.NDArray): True y labels
        y_prob (npt.NDArray): Predicted probabilities
        thresholds (npt.NDArray): Thresholds

    Returns:
        tuple[npt.NDArray, npt.NDArray]: Macro precision and recall per threshold
    """
    y_true = np.repeat(y_true[None, :, :], len(thresholds), axis=0)
    y_prob = np.repeat(y_prob[None, :, :], len(thresholds), axis=0)
    y_pred = y_prob >= thresholds[:, None, None]

    # compute true positives
    tp = np.sum(np.logical_and(y_true, y_pred), axis=2)

    # compute macro average precision handling all warnings
    with np.errstate(divide="ignore", invalid="ignore"):
        den = np.sum(y_pred, axis=2)
        precision = tp / den
        precision[den == 0] = np.nan
        with warnings.catch_warnings():  # for nan slices
            warnings.simplefilter("ignore", category=RuntimeWarning)
            av_precision = np.nanmean(precision, axis=1)

    # compute macro average recall
    recall = tp / np.sum(y_true, axis=2)
    av_recall = np.mean(recall, axis=1)

    return av_precision, av_recall

multi_f1(y_true, y_prob)

Compute multi-class F1

Parameters:

  • y_true (NDArray) –

    True y labels

  • y_prob (NDArray) –

    Predicted probabilities

Returns:

  • NDArray | float –

    F1 score

Source code in heartkit/metrics.py
def multi_f1(y_true: npt.NDArray, y_prob: npt.NDArray) -> npt.NDArray | float:
    """Compute multi-class F1

    Args:
        y_true (npt.NDArray): True y labels
        y_prob (npt.NDArray): Predicted probabilities

    Returns:
        npt.NDArray|float: F1 score
    """
    return f1(y_true, y_prob, multiclass=True, threshold=0.5)

multilabel_confusion_matrix_plot(y_true, y_pred, labels, save_path=None, normalize=False, max_cols=5, **kwargs)

Generate multilabel confusion matrix plot via matplotlib/seaborn

Parameters:

  • y_true (NDArray) –

    True y labels

  • y_pred (NDArray) –

    Predicted y labels

  • labels (list[str]) –

    Label names

  • save_path (str | None, default: None ) –

    Path to save plot. Defaults to None.

  • normalize (Literal['true', 'pred', 'all'] | None, default: None ) –

    Normalize. Defaults to None.

  • max_cols (int, default: 5 ) –

    Max columns. Defaults to 5.

Returns:

  • tuple[Figure, Axes] | None

    tuple[plt.Figure, plt.Axes] | None: Figure and axes

Source code in heartkit/metrics.py
def multilabel_confusion_matrix_plot(
    y_true: npt.NDArray,
    y_pred: npt.NDArray,
    labels: list[str],
    save_path: os.PathLike | None = None,
    normalize: Literal["true", "pred", "all"] | None = None,
    max_cols: int = 5,
    **kwargs,
) -> tuple[plt.Figure, plt.Axes] | None:
    """Generate multilabel confusion matrix plot via matplotlib/seaborn

    Args:
        y_true (npt.NDArray): True y labels
        y_pred (npt.NDArray): Predicted y labels
        labels (list[str]): Label names
        save_path (str | None): Path to save plot. Defaults to None.
        normalize (Literal["true", "pred", "all"] | None): Normalize. Defaults to None.
        max_cols (int): Max columns. Defaults to 5.

    Returns:
        tuple[plt.Figure, plt.Axes] | None: Figure and axes
    """
    cms = multilabel_confusion_matrix(y_true, y_pred)
    ncols = min(len(labels), max_cols)
    nrows = len(labels) // ncols + (len(labels) % ncols > 0)
    width = 10
    hdim = width / ncols
    fig, ax = plt.subplots(figsize=kwargs.get("figsize", (width, hdim * nrows)), nrows=nrows, ncols=ncols)
    for i, label in enumerate(labels):
        r = i // ncols
        c = i % ncols
        ann, fmt = True, "g"
        cm = cms[i]  # cm will be 2x2
        cmn = cm
        if normalize == "true":
            cmn = cmn / cmn.sum(axis=1, keepdims=True)
        elif normalize == "pred":
            cmn = cmn / cmn.sum(axis=0, keepdims=True)
        elif normalize == "all":
            cmn = cmn / cmn.sum()
        cmn = np.nan_to_num(cmn)
        if normalize:
            ann = np.asarray([f"{c:g}{os.linesep}{nc:.2%}" for c, nc in zip(cm.flatten(), cmn.flatten())]).reshape(
                cm.shape
            )
            fmt = ""
        # END IF
        cm_ax = ax[c] if nrows <= 1 else ax[r][c]
        sns.heatmap(cmn, xticklabels=False, yticklabels=False, annot=ann, fmt=fmt, ax=cm_ax)
        cm_ax.set_title(label)
    # END FOR
    # Remove unused subplots
    for i in range(len(labels), nrows * ncols):
        r = i // ncols
        c = i % ncols
        cm_ax = ax[c] if nrows <= 1 else ax[r][c]
        fig.delaxes(cm_ax)
    # END FOR
    fig.text(0.5, 0.04, "Prediction", ha="center")
    fig.text(0, 0.5, "Label", va="center", rotation="vertical")
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, bbox_inches="tight")
        plt.close(fig)
        return None
    # END IF
    return fig, ax

px_plot_confusion_matrix(y_true, y_pred, labels, normalize=False, save_path=None, title=None, width=None, height=400, bg_color='rgba(38,42,50,1.0)')

Generate confusion matrix plot via plotly

Parameters:

  • y_true (NDArray) –

    True y labels

  • y_pred (NDArray) –

    Predicted y labels

  • labels (list[str]) –

    Label names

  • normalize (Literal['true', 'pred', 'all'] | None, default: None ) –

    Normalize. Defaults to None.

  • save_path (PathLike | None, default: None ) –

    Path to save plot. Defaults to None.

  • title (str | None, default: None ) –

    Title. Defaults to None.

  • width (int | None, default: None ) –

    Width. Defaults to None.

  • height (int | None, default: 400 ) –

    Height. Defaults to 400.

  • bg_color (str, default: 'rgba(38,42,50,1.0)' ) –

    Background color. Defaults to "rgba(38,42,50,1.0)".

Returns:

  • plotly.graph_objs.Figure: Plotly figure

Source code in heartkit/metrics.py
def px_plot_confusion_matrix(
    y_true: npt.NDArray,
    y_pred: npt.NDArray,
    labels: list[str],
    normalize: Literal["true", "pred", "all"] | None = None,
    save_path: os.PathLike | None = None,
    title: str | None = None,
    width: int | None = None,
    height: int | None = 400,
    bg_color: str = "rgba(38,42,50,1.0)",
):
    """Generate confusion matrix plot via plotly

    Args:
        y_true (npt.NDArray): True y labels
        y_pred (npt.NDArray): Predicted y labels
        labels (list[str]): Label names
        normalize (Literal["true", "pred", "all"] | None): Normalize. Defaults to None.
        save_path (os.PathLike | None): Path to save plot. Defaults to None.
        title (str | None): Title. Defaults to None.
        width (int | None): Width. Defaults to None.
        height (int | None): Height. Defaults to 400.
        bg_color (str): Background color. Defaults to "rgba(38,42,50,1.0)".

    Returns:
        plotly.graph_objs.Figure: Plotly figure
    """

    cm = confusion_matrix(y_true, y_pred)
    cmn = cm
    ann = None
    if normalize:
        cmn = confusion_matrix(y_true, y_pred, normalize=normalize)
        ann = np.asarray([f"{c:g}<br>{nc:.2%}" for c, nc in zip(cm.flatten(), cmn.flatten())]).reshape(cm.shape)

    cmn = pd.DataFrame(cmn, index=labels, columns=labels)
    fig = px.imshow(
        cmn,
        labels=dict(x="Predicted", y="Actual", color="Count"),
        text_auto=True,  # text_auto is a top-level px.imshow argument, not a label mapping
        title=title,
        template="plotly_dark",
        color_continuous_scale="Plotly3",
    )
    if ann is not None:
        fig.update_traces(text=ann, texttemplate="%{text}")

    fig.update_layout(
        plot_bgcolor=bg_color,
        paper_bgcolor=bg_color,
        margin=dict(l=20, r=5, t=40, b=20),
        height=height,
        width=width,
        title=title,
    )
    if save_path is not None:
        fig.write_html(save_path, include_plotlyjs="cdn", full_html=False)

    return fig

roc_auc_plot(y_true, y_prob, labels, save_path=None, **kwargs)

Generate ROC plot via matplotlib/seaborn

Parameters:

  • y_true (NDArray) –

    True y labels

  • y_prob (NDArray) –

    Predicted probabilities

  • labels (list[str]) –

    Label names

  • save_path (str | None, default: None ) –

    Path to save plot. Defaults to None.

Returns:

  • tuple[Figure, Axes] | None

    tuple[plt.Figure, plt.Axes] | None: Figure and axes

Source code in heartkit/metrics.py
def roc_auc_plot(
    y_true: npt.NDArray,
    y_prob: npt.NDArray,
    labels: list[str],
    save_path: os.PathLike | None = None,
    **kwargs,
) -> tuple[plt.Figure, plt.Axes] | None:
    """Generate ROC plot via matplotlib/seaborn

    Args:
        y_true (npt.NDArray): True y labels
        y_prob (npt.NDArray): Predicted probabilities
        labels (list[str]): Label names
        save_path (str | None): Path to save plot. Defaults to None.

    Returns:
        tuple[plt.Figure, plt.Axes] | None: Figure and axes
    """

    fpr, tpr, _ = roc_curve(y_true, y_prob)
    roc_auc = auc(fpr, tpr)
    fig, ax = plt.subplots(figsize=kwargs.get("figsize", (10, 8)))
    label = f"ROC curve (area = {roc_auc:0.2f})"
    ax.plot(fpr, tpr, lw=2, color="darkorange", label=label)
    ax.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("ROC-AUC")
    fig.legend(loc="lower right")
    if save_path:
        fig.savefig(save_path)
        plt.close(fig)
        return None
    return fig, ax
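
A binary-task sketch: y_prob holds the positive-class probability per sample. Note the labels argument is accepted but not currently drawn on the plot.

import numpy as np
from heartkit.metrics import roc_auc_plot

y_true = np.array([0, 0, 1, 1])
y_prob = np.array([0.1, 0.4, 0.35, 0.8])
fig, ax = roc_auc_plot(y_true, y_prob, labels=["normal", "afib"])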

heartkit.utils

download_file(src, dst, progress=True, chunk_size=8192)

Download file from the supplied URL to the destination path, streaming in chunks.

Parameters:

  • src (str) –

    Source URL path

  • dst (PathLike) –

    Destination file path

  • progress (bool, default: True ) –

    Display progress bar. Defaults to True.

  • chunk_size (int, default: 8192 ) –

    Download chunk size in bytes. Defaults to 8192.

Source code in heartkit/utils.py
def download_file(src: str, dst: os.PathLike, progress: bool = True, chunk_size: int = 8192):
    """Download file from supplied url to destination streaming.

    Args:
        src (str): Source URL path
        dst (PathLike): Destination file path
        progress (bool, optional): Display progress bar. Defaults to True.
        chunk_size (int, optional): Download chunk size in bytes. Defaults to 8192.

    """
    with requests.get(src, stream=True, timeout=3600 * 24) as r:
        r.raise_for_status()
        req_len = int(r.headers.get("Content-length", 0))
        prog_bar = tqdm(total=req_len, unit="iB", unit_scale=True) if progress else None
        with open(dst, "wb") as f:
            for chunk in r.iter_content(chunk_size=chunk_size):
                f.write(chunk)
                if prog_bar:
                    prog_bar.update(len(chunk))
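
A usage sketch; the URL and destination are illustrative. The response is streamed in chunks, so large archives never need to fit in memory.

from heartkit.utils import download_file

# Hypothetical dataset archive URL
download_file(
    src="https://example.com/datasets/ludb.zip",
    dst="./datasets/ludb.zip",
    progress=True,
)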

env_flag(env_var, default=False)

Return the specified environment variable coerced to a bool, as follows:

  • When the variable is unset, or set to the empty string, return default.
  • When the variable is set to a truthy value (1, true, yes, on), return True.
  • When the variable is set to anything else (e.g. 0, no), return False.

Case and leading/trailing whitespace are ignored.

Parameters:

  • env_var (str) –

    Environment variable name

  • default (bool, default: False ) –

    Default value. Defaults to False.

Returns:

  • bool ( bool ) –

    Value of environment variable

Source code in heartkit/utils.py
def env_flag(env_var: str, default: bool = False) -> bool:
    """Return the specified environment variable coerced to a bool, as follows:
    - When the variable is unset, or set to the empty string, return `default`.
    - When the variable is set to a truthy value, returns `True`.
      These are the truthy values:
          - 1
          - true, yes, on
    - When the variable is set to anything else, returns False.
       Example falsy values:
          - 0
          - no
    - Ignore case and leading/trailing whitespace.

    Args:
        env_var (str): Environment variable name
        default (bool, optional): Default value. Defaults to False.

    Returns:
        bool: Value of environment variable
    """
    environ_string = os.environ.get(env_var, "").strip().lower()
    if not environ_string:
        return default
    return environ_string in ["1", "true", "yes", "on"]
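
A quick sketch with hypothetical variable names, showing case/whitespace handling and the default fallback.

import os
from heartkit.utils import env_flag

os.environ["HK_DISABLE_GPU"] = " Yes "       # case and surrounding whitespace are ignored
print(env_flag("HK_DISABLE_GPU"))            # True
print(env_flag("HK_MISSING", default=True))  # True (unset falls back to default)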

load_pkl(file, compress=True)

Load pickled file.

Parameters:

  • file (str) –

    File path (.pkl)

  • compress (bool, default: True ) –

    If file is compressed. Defaults to True.

Returns:

  • dict[str, Any]

    dict[str, Any]: Dictionary of pickled objects

Source code in heartkit/utils.py
def load_pkl(file: str, compress: bool = True) -> dict[str, Any]:
    """Load pickled file.

    Args:
        file (str): File path (.pkl)
        compress (bool, optional): If file is compressed. Defaults to True.

    Returns:
        dict[str, Any]: Dictionary of pickled objects
    """
    if compress:
        with gzip.open(file, "rb") as fh:
            return pickle.load(fh)
    else:
        with open(file, "rb") as fh:
            return pickle.load(fh)

resolve_template_path(fpath, **kwargs)

Resolve templated path w/ supplied substitutions.

Parameters:

  • fpath (Path) –

    File path

  • **kwargs (Any, default: {} ) –

    Template arguments

Returns:

  • Path ( Path ) –

    Resolved file path

Source code in heartkit/utils.py
def resolve_template_path(fpath: Path, **kwargs: Any) -> Path:
    """Resolve templated path w/ supplied substitutions.

    Args:
        fpath (Path): File path
        **kwargs (Any): Template arguments

    Returns:
        Path: Resolved file path
    """
    return Path(Template(str(fpath)).safe_substitute(**kwargs))
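
Templates use string.Template syntax ($name placeholders), and safe_substitute leaves unknown placeholders intact rather than raising. A short sketch with illustrative names:

from pathlib import Path
from heartkit.utils import resolve_template_path

tpl = Path("$job_dir/$task/model.keras")
print(resolve_template_path(tpl, job_dir="results", task="segmentation"))
# results/segmentation/model.keras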

save_pkl(file, compress=True, **kwargs)

Save python objects into pickle file.

Parameters:

  • file (str) –

    File path (.pkl)

  • compress (bool, default: True ) –

    Whether to compress file. Defaults to True.

Source code in heartkit/utils.py
def save_pkl(file: str, compress: bool = True, **kwargs):
    """Save python objects into pickle file.

    Args:
        file (str): File path (.pkl)
        compress (bool, optional): Whether to compress file. Defaults to True.
    """
    if compress:
        with gzip.open(file, "wb") as fh:
            pickle.dump(kwargs, fh, protocol=4)
    else:
        with open(file, "wb") as fh:
            pickle.dump(kwargs, fh, protocol=4)
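
A round-trip sketch with load_pkl: objects are passed as keyword arguments and come back as a dict keyed by those names; both calls must agree on compress.

from heartkit.utils import load_pkl, save_pkl

save_pkl("./cache.pkl", x=[1, 2, 3], y=[0, 1, 0])  # gzip-compressed by default
data = load_pkl("./cache.pkl")
print(data["x"], data["y"])  # [1, 2, 3] [0, 1, 0]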

set_random_seed(seed=None)

Set random seed across libraries: TF, Numpy, Python

Parameters:

  • seed (int | None, default: None ) –

    Random seed state to use. Defaults to None.

Returns:

  • int ( int ) –

    Random seed

Source code in heartkit/utils.py
def set_random_seed(seed: int | None = None) -> int:
    """Set random seed across libraries: TF, Numpy, Python

    Args:
        seed (int | None, optional): Random seed state to use. Defaults to None.

    Returns:
        int: Random seed
    """
    seed = seed or np.random.randint(2**16)
    random.seed(seed)
    np.random.seed(seed)
    try:
        import tensorflow as tf  # pylint: disable=import-outside-toplevel
    except ImportError:
        pass
    else:
        tf.random.set_seed(seed)
    return seed

setup_logger(log_name)

Setup logger with Rich

Parameters:

  • log_name (str) –

    Logger name

Returns:

  • Logger

    logging.Logger: Logger

Source code in heartkit/utils.py
def setup_logger(log_name: str) -> logging.Logger:
    """Setup logger with Rich

    Args:
        log_name (str): Logger name

    Returns:
        logging.Logger: Logger
    """
    logger = logging.getLogger(log_name)
    if logger.handlers:
        return logger
    logging.basicConfig(level=logging.ERROR, force=True, handlers=[RichHandler()])
    logger.propagate = False
    logger.setLevel(logging.INFO)
    logger.handlers = [RichHandler()]
    return logger
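
A usage sketch with an illustrative logger name; repeated calls with the same name return the existing logger rather than stacking duplicate handlers.

from heartkit.utils import setup_logger

logger = setup_logger("heartkit.demo")
logger.info("Loading dataset...")  # rendered via Rich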