shape_propagation
Classes
ShapeInferenceResult
dataclass
ShapeInferenceResult(changed_tensor_ids: set[str] = set(), unresolved_tensor_ids: set[str] = set(), diagnostics: list[str] = list())
Shape propagation outcome.
Attributes:
- changed_tensor_ids (set[str]) – Tensor IDs whose shapes changed during propagation.
- unresolved_tensor_ids (set[str]) – IDs of dynamic tensors whose shapes remain unresolved after propagation.
- diagnostics (list[str]) – Human-readable diagnostics describing inference updates and unresolved tensors. Diagnostics for skipped rules and fallbacks are recorded only when DEBUG logging is enabled, to avoid unbounded growth on large graphs.
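The rendered signature shows `set()` and `list()` as defaults. Because dataclasses reject mutable defaults, these fields are presumably declared with `field(default_factory=...)`. The sketch below works under that assumption. The logger name and the `add_skip_diagnostic` helper are hypothetical; they show one way to record skip diagnostics only when DEBUG logging is enabled.

```python
import logging
from dataclasses import dataclass, field

# Hypothetical logger name; the real module's logger may differ.
logger = logging.getLogger("shape_propagation")


@dataclass
class ShapeInferenceResult:
    """Shape propagation outcome (sketch: defaults assumed to use default_factory)."""

    # default_factory gives every instance its own fresh set/list.
    changed_tensor_ids: set[str] = field(default_factory=set)
    unresolved_tensor_ids: set[str] = field(default_factory=set)
    diagnostics: list[str] = field(default_factory=list)

    def add_skip_diagnostic(self, message: str) -> None:
        """Hypothetical helper: record a skip diagnostic only under DEBUG logging."""
        if logger.isEnabledFor(logging.DEBUG):
            self.diagnostics.append(message)
```

With `default_factory`, mutating one result's `changed_tensor_ids` does not affect any other instance. A plain `= set()` default would be rejected by `@dataclass` at class-definition time.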
Functions
propagate_shapes
Infer and propagate tensor shapes over the graph until a fixed point is reached.
Parameters:
- model (AirModel) – AIR model to update in place.
- strict (bool, default: False) – If True, raise when unresolved dynamic tensors remain.
Returns:
- ShapeInferenceResult – Shape inference result with updated tensors and diagnostics.
Raises:
- ValueError – If strict is True and unresolved dynamic tensors remain.
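The fixed-point loop and the strict check can be illustrated with a toy stand-in. This is a sketch and not the module's implementation. `ToyModel`, `ToyOp`, `ToyResult`, `toy_propagate_shapes`, and the `max_passes` guard are all hypothetical, and a shape of `None` stands in for an unresolved dynamic shape. On each pass, every op whose input shapes are known applies its shape rule. Propagation stops once a full pass changes nothing.

```python
from dataclasses import dataclass, field
from typing import Callable, Optional

Shape = tuple[int, ...]


@dataclass
class ToyOp:
    """Hypothetical op: computes its output shape from its input shapes."""
    inputs: list[str]
    output: str
    rule: Callable[[list[Shape]], Shape]


@dataclass
class ToyModel:
    """Hypothetical stand-in for AirModel; None marks an unresolved dynamic shape."""
    shapes: dict[str, Optional[Shape]]
    ops: list[ToyOp]


@dataclass
class ToyResult:
    changed_tensor_ids: set[str] = field(default_factory=set)
    unresolved_tensor_ids: set[str] = field(default_factory=set)
    diagnostics: list[str] = field(default_factory=list)


def toy_propagate_shapes(model: ToyModel, strict: bool = False,
                         max_passes: int = 100) -> ToyResult:
    result = ToyResult()
    for _ in range(max_passes):  # guard against graphs that never converge
        changed = False
        for op in model.ops:
            in_shapes = [model.shapes.get(t) for t in op.inputs]
            if any(s is None for s in in_shapes):
                continue  # an input is still unknown; retry on a later pass
            new_shape = op.rule(in_shapes)
            if model.shapes.get(op.output) != new_shape:
                model.shapes[op.output] = new_shape  # update the model in place
                result.changed_tensor_ids.add(op.output)
                result.diagnostics.append(f"{op.output}: inferred {new_shape}")
                changed = True
        if not changed:
            break  # fixed point: a full pass produced no updates
    else:
        result.diagnostics.append("stopped after max_passes without a fixed point")
    result.unresolved_tensor_ids = {t for t, s in model.shapes.items() if s is None}
    for t in sorted(result.unresolved_tensor_ids):
        result.diagnostics.append(f"{t}: unresolved")
    if strict and result.unresolved_tensor_ids:
        raise ValueError(f"unresolved dynamic tensors: {sorted(result.unresolved_tensor_ids)}")
    return result
```

In the example below, the consumer op is listed before its producer, so `z` resolves only on the second pass. `q` has no producer and stays unresolved.

```python
m = ToyModel(
    shapes={"x": (2, 3), "w": (3, 4), "y": None, "z": None, "q": None},
    ops=[
        ToyOp(["y"], "z", lambda s: s[0]),                     # elementwise: same shape
        ToyOp(["x", "w"], "y", lambda s: (s[0][0], s[1][1])),  # matmul
    ],
)
res = toy_propagate_shapes(m)
# m.shapes["z"] is now (2, 4); res.unresolved_tensor_ids == {"q"}
```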