
shape_propagation

Classes

ShapeInferenceResult dataclass

ShapeInferenceResult(changed_tensor_ids: set[str] = set(), unresolved_tensor_ids: set[str] = set(), diagnostics: list[str] = list())

Shape propagation outcome.

Attributes:

  • changed_tensor_ids (set[str]) –

    Tensor IDs whose shapes changed during propagation.

  • unresolved_tensor_ids (set[str]) –

    Dynamic tensor IDs that remain unresolved.

  • diagnostics (list[str]) –

    Human-readable diagnostics describing inference updates and unresolved tensors. Diagnostics for skipped work (a rule or fallback that was skipped) are included only when DEBUG logging is enabled, to avoid unbounded growth on large graphs.
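
A dataclass cannot use set() or list() directly as a field default, so fields like these are normally declared with field(default_factory=...), giving each result its own containers. The sketch below is an illustrative stand-in, not the library's class: ShapeInferenceResultSketch, its add_skip_diagnostic method, and the logger name are hypothetical, and the DEBUG check only mirrors the behavior described for skip diagnostics.

```python
import logging
from dataclasses import dataclass, field

# Hypothetical logger name, used only for this sketch.
logger = logging.getLogger("shape_propagation_sketch")


@dataclass
class ShapeInferenceResultSketch:
    """Illustrative stand-in mirroring the documented fields (not the library class)."""

    # default_factory gives every instance its own set/list; a bare set() default
    # is rejected by dataclasses as a mutable default.
    changed_tensor_ids: set[str] = field(default_factory=set)
    unresolved_tensor_ids: set[str] = field(default_factory=set)
    diagnostics: list[str] = field(default_factory=list)

    def add_skip_diagnostic(self, message: str) -> None:
        # Record "rule/fallback skipped" messages only at DEBUG level so that
        # large graphs do not accumulate one entry per skipped rule.
        if logger.isEnabledFor(logging.DEBUG):
            self.diagnostics.append(message)


first = ShapeInferenceResultSketch()
second = ShapeInferenceResultSketch()
first.changed_tensor_ids.add("t0")  # does not affect second.changed_tensor_ids
```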

Functions

propagate_shapes

propagate_shapes(model: AirModel, *, strict: bool = False) -> ShapeInferenceResult

Infer and propagate tensor shapes over the graph until a fixed point is reached.
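
As a rough illustration of propagating until a fixed point, the sketch below runs repeated sweeps over a toy graph and stops once a full sweep changes no shape. It does not use AirModel: the (inputs, output, rule) op tuples, the Shape alias, same_as_first, and propagate_to_fixed_point are all hypothetical. It also only fills in shapes that are still unknown, which guarantees termination; the real pass may treat partial or refined shapes differently.

```python
from typing import Callable, Optional

# Hypothetical toy representation (not AirModel): None marks an unknown shape.
Shape = Optional[tuple[int, ...]]
Rule = Callable[[list[Shape]], Shape]
Op = tuple[list[str], str, Rule]  # (input tensor IDs, output tensor ID, shape rule)


def same_as_first(in_shapes: list[Shape]) -> Shape:
    # Elementwise-style rule: output shape equals the first input's shape.
    return in_shapes[0]


def propagate_to_fixed_point(shapes: dict[str, Shape], ops: list[Op]) -> tuple[set[str], set[str]]:
    """Fill unknown shapes in place; return (changed_ids, unresolved_ids)."""
    changed: set[str] = set()
    progress = True
    while progress:  # keep sweeping until one full sweep changes nothing
        progress = False
        for inputs, output, rule in ops:
            if shapes.get(output) is not None:
                continue  # already known; this sketch never overwrites a shape
            in_shapes = [shapes.get(t) for t in inputs]
            if any(s is None for s in in_shapes):
                continue  # inputs not ready; a later sweep may resolve them
            new_shape = rule(in_shapes)
            if new_shape is not None:
                shapes[output] = new_shape
                changed.add(output)
                progress = True
    unresolved = {t for t, s in shapes.items() if s is None}
    return changed, unresolved


# Ops listed out of dependency order: z needs y, which is only known after sweep 1.
shapes: dict[str, Shape] = {"x": (2, 3), "y": None, "z": None, "u": None, "v": None}
ops: list[Op] = [
    (["y"], "z", same_as_first),
    (["x"], "y", same_as_first),
    (["u"], "v", same_as_first),  # u is never known, so v stays unresolved
]
changed, unresolved = propagate_to_fixed_point(shapes, ops)
```

Because the ops are out of dependency order, z is resolved only on the second sweep, after y becomes known; a single pass would leave it unknown, which is why iterating to a fixed point matters.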

Parameters:

  • model

    (AirModel) –

    AIR model to update in place.

  • strict

    (bool, default: False ) –

    If True, raise ValueError when unresolved dynamic tensors remain after propagation.

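Returns:

  • ShapeInferenceResult –

    Shape propagation outcome: tensor IDs whose shapes changed, dynamic tensor IDs that remain unresolved, and diagnostics.
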
Raises:

  • ValueError

    If strict is True and unresolved dynamic tensors remain.
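
The strict contract can be pictured as a final check after propagation: unresolved dynamic tensors are reported, and only strict mode turns them into a ValueError. The finalize helper and its message format below are hypothetical, a sketch of the documented behavior rather than the library's code.

```python
def finalize(unresolved: set[str], diagnostics: list[str], *, strict: bool = False) -> None:
    """Hypothetical post-propagation check mirroring the documented strict contract."""
    # In this sketch, unresolved tensors are reported in diagnostics in both modes.
    for tensor_id in sorted(unresolved):
        diagnostics.append(f"unresolved dynamic tensor: {tensor_id}")
    # Only strict mode escalates remaining unresolved tensors to an error.
    if strict and unresolved:
        raise ValueError(
            f"{len(unresolved)} dynamic tensor(s) remain unresolved: "
            + ", ".join(sorted(unresolved))
        )


diagnostics: list[str] = []
finalize({"t1"}, diagnostics)  # non-strict: records a diagnostic, no exception
try:
    finalize({"t1"}, [], strict=True)
except ValueError as exc:
    error_message = str(exc)
```

With the documented API, a caller that prefers to handle gaps itself would keep strict=False and inspect result.unresolved_tensor_ids instead of catching ValueError.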