Skip to content

train

Classes

Functions

train

train(params: HKTaskParams)

Train model for foundation task using SimCLR

Parameters: params (HKTaskParams) – Task parameters

Source code in heartkit/tasks/foundation/train.py
def train(params: HKTaskParams) -> None:
    """Train model for foundation task using SimCLR

    Creates the job directory, writes the resolved configuration to
    ``configuration.json``, builds the encoder, fits it through
    ``helia.trainers.SimCLRTrainer`` and logs the validation metrics.

    Args:
        params (HKTaskParams): Task parameters. ``temperature`` and ``seed``
            are overwritten in place, and ``model_file`` is filled in with
            ``<job_dir>/model.keras`` when it is None.
    """
    os.makedirs(params.job_dir, exist_ok=True)
    logger = helia.utils.setup_logger(__name__, level=params.verbose, file_path=params.job_dir / "train.log")
    logger.debug(f"Creating working directory in {params.job_dir}")

    # Fall back to 0.1 when the params object carries no temperature attribute
    params.temperature = float(getattr(params, "temperature", 0.1))

    # Store whatever seed set_random_seed returns so it ends up in the saved configuration
    params.seed = helia.utils.set_random_seed(params.seed)
    logger.debug(f"Random seed {params.seed}")

    with open(params.job_dir / "configuration.json", "w", encoding="utf-8") as fp:
        fp.write(params.model_dump_json(indent=2))

    if helia.utils.env_flag("WANDB"):
        wandb.init(project=params.project, entity="ambiq", dir=params.job_dir)
        wandb.config.update(params.model_dump())
    # END IF

    feat_shape = (params.frame_size, 1)

    datasets = [DatasetFactory.get(ds.name)(**ds.params) for ds in params.datasets]

    train_ds, val_ds = load_train_datasets(datasets=datasets, params=params)

    # Save validation data
    if params.val_file:
        logger.info(f"Saving validation dataset to {params.val_file}")
        os.makedirs(params.val_file, exist_ok=True)
        val_ds.save(str(params.val_file))

    # Create encoder
    encoder_input = keras.Input(shape=feat_shape, dtype="float32")
    encoder = ModelFactory.get(params.architecture.name)(
        inputs=encoder_input,
        params=params.architecture.params,
        num_classes=None,
    )

    flops = helia.metrics.flops.get_flops(encoder, batch_size=1, fpath=params.job_dir / "encoder_flops.log")
    encoder.summary(print_fn=logger.info)
    logger.debug(f"Encoder requires {flops / 1e6:0.2f} MFLOPS")

    # Create  projector
    # NOTE: The projector head is currently disabled (projector=None is passed
    # to the trainer below); the code is kept for reference.
    # encoder_output = encoder(encoder_input)
    # projection_width = params.num_classes
    # projector_input = encoder_output
    # projector_output = keras.layers.Dense(projection_width, activation="relu6")(projector_input)
    # projector_output = keras.layers.Dense(projection_width)(projector_output)
    # projector = keras.Model(inputs=projector_input, outputs=projector_output, name="projector")
    # flops = helia.metrics.flops.get_flops(projector, batch_size=1, fpath=params.job_dir / "projector_flops.log")
    # projector.summary(print_fn=logger.info)
    # logger.debug(f"Projector requires {flops/1e6:0.2f} MFLOPS")

    if params.model_file is None:
        params.model_file = params.job_dir / "model.keras"

    model = helia.trainers.SimCLRTrainer(
        encoder=encoder,
        projector=None,
    )

    def get_scheduler():
        """Build the cosine-decay-with-restarts learning rate schedule."""
        t_mul = 1
        # With t_mul == 1 the denominator reduces to lr_cycles, i.e. the total
        # number of training steps is split evenly across the cycles.
        first_steps = (params.steps_per_epoch * params.epochs) / (np.power(params.lr_cycles, t_mul) - t_mul + 1)
        scheduler = keras.optimizers.schedules.CosineDecayRestarts(
            initial_learning_rate=params.lr_rate,
            # Keras documents this as an integer step count, so avoid handing
            # it the NumPy float returned by np.ceil.
            first_decay_steps=int(np.ceil(first_steps)),
            t_mul=t_mul,
            m_mul=0.5,
        )
        return scheduler

    model.compile(
        encoder_optimizer=keras.optimizers.Adam(get_scheduler()),
        encoder_loss=helia.losses.simclr.SimCLRLoss(temperature=params.temperature),
        encoder_metrics=[keras.metrics.MeanSquaredError(name="mse"), keras.metrics.CosineSimilarity(name="cos")],
    )

    # Swap in the W&B checkpoint callback when W&B logging is enabled
    ModelCheckpoint = keras.callbacks.ModelCheckpoint
    if helia.utils.env_flag("WANDB"):
        ModelCheckpoint = WandbModelCheckpoint

    model_callbacks = [
        keras.callbacks.EarlyStopping(
            monitor=f"val_{params.val_metric}",
            # Wait a quarter of the epoch budget (at least one epoch) before stopping
            patience=max(int(0.25 * params.epochs), 1),
            mode="max" if params.val_metric == "f1" else "auto",
            restore_best_weights=True,
            verbose=max(0, params.verbose - 1),
        ),
        ModelCheckpoint(
            filepath=str(params.model_file),
            monitor=f"val_{params.val_metric}",
            save_best_only=True,
            mode="max" if params.val_metric == "f1" else "auto",
            verbose=max(0, params.verbose - 1),
        ),
        keras.callbacks.CSVLogger(params.job_dir / "history.csv"),
    ]
    if helia.utils.env_flag("TENSORBOARD"):
        model_callbacks.append(
            keras.callbacks.TensorBoard(
                log_dir=params.job_dir,
                write_steps_per_second=True,
            )
        )
    if helia.utils.env_flag("WANDB"):
        model_callbacks.append(WandbMetricsLogger())
    # Use minimal progress bar
    if params.verbose <= 1:
        model_callbacks.append(
            helia.callbacks.TQDMProgressBar(
                show_epoch_progress=False,
            )
        )

    # fit() never returns when it is interrupted, so pre-bind history; without
    # this the check further down raises UnboundLocalError after Ctrl-C.
    history = None
    try:
        history = model.fit(
            train_ds,
            steps_per_epoch=params.steps_per_epoch,
            verbose=max(0, params.verbose - 1),
            epochs=params.epochs,
            validation_data=val_ds,
            callbacks=model_callbacks,
        )
    except KeyboardInterrupt:
        logger.warning("Stopping training due to keyboard interrupt")

    logger.debug(f"Model saved to {params.model_file}")

    setup_plotting()
    # Only plot when training ran to completion and produced a history
    if history is not None:
        helia.plotting.plot_history_metrics(
            history.history,
            metrics=["loss", "cos"],
            save_path=params.job_dir / "history.png",
            title="Training History",
            stack=True,
            figsize=(9, 5),
        )

    metrics = model.evaluate(val_ds, verbose=2, return_dict=True)
    logger.info("[VAL SET] " + ", ".join(f"{k.upper()}: {v:.4f}" for k, v in metrics.items()))

    # cleanup
    keras.utils.clear_session()
    for ds in datasets:
        ds.close()