Bring-Your-Own-Model (BYOM)
The model factory can be extended to include custom models. This is useful when you have a custom model architecture that you would like to use for training. To register a custom model with the `sk.ModelFactory`, define a custom model function and register it with the factory.
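Conceptually, a model factory is a registry that maps a name to a callable. The sketch below shows what `register` and `get` do in that sense. It is a minimal, hypothetical stand-in: the `SimpleFactory` class and its `has` and `names` methods are illustrative assumptions, not sleepkit's implementation.

```python
from typing import Callable, Dict, Generic, List, TypeVar

T = TypeVar("T")


class SimpleFactory(Generic[T]):
    """Hypothetical name -> item registry, for illustration only (not sleepkit's code)."""

    def __init__(self) -> None:
        self._items: Dict[str, T] = {}

    def register(self, name: str, item: T) -> None:
        # In this sketch, re-registering a name silently replaces the old item
        self._items[name] = item

    def has(self, name: str) -> bool:
        return name in self._items

    def get(self, name: str) -> T:
        # Fail with a helpful message when the name was never registered
        if name not in self._items:
            raise KeyError(f"'{name}' is not registered; known: {self.names()}")
        return self._items[name]

    def names(self) -> List[str]:
        return sorted(self._items)


# Usage: register a callable under a name, then look it up and call it
factory: SimpleFactory[Callable[[int], int]] = SimpleFactory()
factory.register("double", lambda n: 2 * n)
print(factory.get("double")(21))  # 42
```

The key design point is that callers only need the registered name. The factory hands back the callable, and the caller supplies the arguments.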
How it Works
- Create a Model: Define a new model function that takes three arguments: an input tensor (for example, one created with `keras.Input`), the model parameters, and the number of classes. The function must return a `keras.Model`.
```python
import keras
import sleepkit as sk

def custom_model_from_params(
    x: keras.KerasTensor,
    params: dict,
    num_classes: int | None = None,
) -> keras.Model:
    y = x
    # Create fully connected network from params
    for layer in params["layers"]:
        y = keras.layers.Dense(layer["units"], activation=layer["activation"])(y)
    if num_classes:
        y = keras.layers.Dense(num_classes, activation="softmax")(y)
    return keras.Model(inputs=x, outputs=y)
```
- Register the Model: Register the new model function with the `sk.ModelFactory` by calling its `register` method. This method takes the model name and the callable as arguments.
```python
sk.ModelFactory.register("custom-model", custom_model_from_params)
```
- Use the Model: The new model can now be retrieved by name from the `sk.ModelFactory` and instantiated like any other registered model.
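```python
inputs = keras.Input(shape=(100,))
# Look up the registered function by name, then call it to build the model
model = sk.ModelFactory.get("custom-model")(
    x=inputs,  # keyword must match the function's first parameter name
    params={
        "layers": [
            {"units": 64, "activation": "relu"},
            {"units": 32, "activation": "relu"},
        ]
    },
    num_classes=5,
)
model.summary()
```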
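As a sanity check on `model.summary()`, the parameter count of this network can be worked out by hand. Each `Dense` layer with a bias has `fan_in * units + units` trainable weights. The helper below is a hypothetical sketch, not part of sleepkit, that applies this formula to the same `params` dictionary. It assumes the 100-feature input and 5-class head used above.

```python
from typing import Optional


def dense_param_count(input_dim: int, params: dict, num_classes: Optional[int] = None) -> int:
    """Count weights + biases for the fully connected network built above."""
    units_list = [layer["units"] for layer in params["layers"]]
    if num_classes:
        units_list.append(num_classes)  # softmax classification head
    total = 0
    fan_in = input_dim
    for units in units_list:
        # Dense(units) with bias: kernel (fan_in x units) + bias (units)
        total += fan_in * units + units
        fan_in = units
    return total


params = {
    "layers": [
        {"units": 64, "activation": "relu"},
        {"units": 32, "activation": "relu"},
    ]
}
print(dense_param_count(100, params, num_classes=5))  # 8709
```

The breakdown is (100·64 + 64) + (64·32 + 32) + (32·5 + 5) = 6464 + 2080 + 165 = 8709. This total should agree with the count that `model.summary()` reports for this configuration.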
Better Model Params
Rather than using a plain dictionary for the model parameters, you can define a custom dataclass or Pydantic model. Either one enables type checking and documents the expected fields. A Pydantic model additionally validates the parameters at runtime.
```python
import keras
from pydantic import BaseModel

class CustomLayerParams(BaseModel):
    units: int
    activation: str

class CustomModelParams(BaseModel):
    layers: list[CustomLayerParams]

def custom_model_from_params(
    x: keras.KerasTensor,
    params: dict,
    num_classes: int | None = None,
) -> keras.Model:
    # Convert and validate params (raises pydantic.ValidationError on bad input)
    params = CustomModelParams(**params)
    y = x
    # Create fully connected network from params
    for layer in params.layers:
        y = keras.layers.Dense(layer.units, activation=layer.activation)(y)
    if num_classes:
        y = keras.layers.Dense(num_classes, activation="softmax")(y)
    return keras.Model(inputs=x, outputs=y)
```
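If you prefer the standard-library dataclass option mentioned above, the following is one possible sketch. Dataclasses do not validate types at runtime and do not convert nested dictionaries. This sketch therefore adds a hand-written check and a `from_dict` helper; both the `from_dict` name and the positive-`units` check are illustrative assumptions, not part of sleepkit.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class CustomLayerParams:
    units: int
    activation: str

    def __post_init__(self) -> None:
        # Dataclasses do not enforce annotations at runtime, so check the key field by hand
        if isinstance(self.units, bool) or not isinstance(self.units, int) or self.units <= 0:
            raise ValueError(f"units must be a positive int, got {self.units!r}")


@dataclass
class CustomModelParams:
    layers: List[CustomLayerParams]

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "CustomModelParams":
        # Nested dicts are not converted automatically; build each layer explicitly
        return cls(layers=[CustomLayerParams(**layer) for layer in data["layers"]])


params = CustomModelParams.from_dict(
    {"layers": [{"units": 64, "activation": "relu"}, {"units": 32, "activation": "relu"}]}
)
print([layer.units for layer in params.layers])  # [64, 32]
```

To use this sketch inside `custom_model_from_params`, replace the Pydantic conversion with `params = CustomModelParams.from_dict(params)`. The loop over `params.layers` stays the same. Unknown keys in a layer dictionary raise a `TypeError` from the generated `__init__`.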