Source code for materforge.algorithms.piecewise_inverter

# SPDX-FileCopyrightText: 2025 - 2026 Rahil Miten Doshi, Friedrich-Alexander-Universität Erlangen-Nürnberg
# SPDX-FileCopyrightText: 2026 Matthias Markl, Friedrich-Alexander-Universität Erlangen-Nürnberg
# SPDX-License-Identifier: BSD-3-Clause

import logging
from typing import List, Optional, Union
import sympy as sp

logger = logging.getLogger(__name__)


class PiecewiseInverter:
    """Creates inverse functions for linear piecewise functions (degree <= 1 only).

    Each piece is inverted algebraically (``y = a*x + b  =>  x = (y - b) / a``)
    and given a condition on the output value at the piece's boundary. The way
    those conditions are built (first match wins, one threshold per piece)
    assumes the original conditions are upper bounds in ascending order, e.g.
    ``(f1, x < b1), (f2, x < b2), ..., (fn, True)``.
    """

    def __init__(self, tolerance: float = 1e-12) -> None:
        """Stores the numerical tolerance used by the inversion.

        Args:
            tolerance: Slopes whose absolute value does not exceed this are
                treated as neither increasing nor decreasing in the
                monotonicity check, and linear coefficients smaller than this
                are rejected as too small for a stable inversion.
        """
        self.tolerance = tolerance

    @staticmethod
    def create_inverse(piecewise_func: Union[sp.Piecewise, sp.Expr],
                       input_symbol: Union[sp.Symbol, sp.Basic],
                       output_symbol: Union[sp.Symbol, sp.Basic],
                       tolerance: float = 1e-12) -> sp.Piecewise:
        """Creates the inverse of a linear piecewise function.

        Args:
            piecewise_func: Original piecewise function f(input_symbol).
            input_symbol: Independent variable of the original function.
            output_symbol: Symbol for the inverse function's argument.
            tolerance: Numerical tolerance for inversion stability.

        Returns:
            Inverse piecewise function expressed in terms of output_symbol.

        Raises:
            ValueError: If piecewise_func is not a Piecewise, any piece has
                degree > 1, the function is non-monotonic, or a piece cannot
                be processed or inverted.

        Example:
            >>> dep = sp.Symbol("T")
            >>> E = sp.Symbol("E")
            >>> pw = sp.Piecewise((2*dep + 100, dep < 500), (3*dep - 400, True))
            >>> inv = PiecewiseInverter.create_inverse(pw, dep, E)
        """
        logger.info("Creating inverse function: %s = f_inv(%s)", input_symbol, output_symbol)
        if not isinstance(piecewise_func, sp.Piecewise):
            raise ValueError(f"Expected Piecewise function, got {type(piecewise_func).__name__}")
        inverter = PiecewiseInverter(tolerance)
        return inverter._create_inverse_impl(piecewise_func, input_symbol, output_symbol)

    # --- Internal implementation ---

    def _create_inverse_impl(self, piecewise_func: sp.Piecewise,
                             input_symbol: Union[sp.Symbol, sp.Basic],
                             output_symbol: Union[sp.Symbol, sp.Basic]) -> sp.Piecewise:
        """Core inversion logic.

        Validates that every piece is linear, collects boundary data per
        piece, validates monotonicity, then inverts every piece and attaches
        a condition on the output value.

        Raises:
            ValueError: If validation fails or a piece cannot be processed.
        """
        logger.info("Inverting piecewise with %d pieces", len(piecewise_func.args))
        self._validate_linear_only(piecewise_func, input_symbol)

        piece_boundaries: List[dict] = []
        for i, (expr, condition) in enumerate(piecewise_func.args):
            is_final = condition is sp.true
            try:
                slope = self._slope(expr, input_symbol)
                if is_final:
                    # The catch-all piece has no boundary of its own; reuse the
                    # previous boundary (returned as the inverse of a constant piece).
                    boundary_val = piece_boundaries[-1]["boundary_val"] if piece_boundaries else None
                    output_bound = float("inf")
                else:
                    boundary_val = self._extract_boundary(condition, input_symbol)
                    output_bound = float(expr.subs(input_symbol, boundary_val))
            except Exception as e:
                raise ValueError(f"Error processing piece {i + 1}: {str(e)}") from e
            piece_boundaries.append({
                "index": i,
                "expr": expr,
                "boundary_val": boundary_val,
                "output_bound": output_bound,
                "is_final": is_final,
                "slope": slope,
            })

        self._validate_monotonicity(piece_boundaries, input_symbol)

        # The direction is decided once for the whole function. Deciding it per
        # piece would give constant pieces (slope 0) the increasing-style
        # condition even inside a decreasing function.
        decreasing = any(p["slope"] < -self.tolerance for p in piece_boundaries)

        inverse_conditions = []
        for piece in piece_boundaries:
            inv_expr = self._invert_linear_expression(
                piece["expr"], input_symbol, output_symbol, piece["boundary_val"])
            if piece["is_final"]:
                inverse_conditions.append((inv_expr, True))
                continue
            ob = piece["output_bound"]
            cond = (output_symbol > ob) if decreasing else (output_symbol < ob)
            inverse_conditions.append((inv_expr, cond))

        result = sp.Piecewise(*inverse_conditions)
        logger.info("Created inverse with %d conditions", len(inverse_conditions))
        return result

    @staticmethod
    def _slope(expr: sp.Expr, input_symbol: sp.Symbol) -> float:
        """Returns the linear coefficient of expr in input_symbol (0.0 unless degree is 1)."""
        if sp.degree(expr, input_symbol) == 1:
            return float(sp.Poly(expr, input_symbol).all_coeffs()[0])
        return 0.0

    def _validate_monotonicity(self, piece_boundaries: List[dict], input_symbol: sp.Symbol) -> None:
        """Raises ValueError if the piecewise function is non-monotonic.

        Every piece, including the final catch-all piece, is checked. Only
        slope signs are compared (slopes within ``tolerance`` of zero are
        neutral); jumps between pieces at the boundaries are not checked.
        """
        if len(piece_boundaries) < 2:
            return
        slopes = [self._slope(p["expr"], input_symbol) for p in piece_boundaries]
        has_increasing = any(s > self.tolerance for s in slopes)
        has_decreasing = any(s < -self.tolerance for s in slopes)
        if has_increasing and has_decreasing:
            raise ValueError("Piecewise function is not monotonic - mix of increasing and decreasing pieces.")
        logger.debug("Monotonicity check passed. Slopes: %s", slopes)

    @staticmethod
    def _validate_linear_only(piecewise_func: sp.Piecewise, input_symbol: sp.Symbol) -> None:
        """Raises ValueError if any piece has degree > 1 or its degree cannot be determined."""
        for i, (expr, _) in enumerate(piecewise_func.args):
            try:
                deg = sp.degree(expr, input_symbol)
            except Exception as e:
                # e.g. non-polynomial pieces such as exp(x)
                raise ValueError(f"Error validating piece {i + 1}: {str(e)}") from e
            if deg > 1:
                raise ValueError(
                    f"Piece {i + 1} has degree {deg}. "
                    f"Only linear functions (degree <= 1) are supported.")

    @staticmethod
    def _extract_boundary(condition: sp.Basic, symbol: sp.Symbol) -> float:
        """Extracts the numeric boundary value from a condition such as dep < 300.0.

        For a relational, the symbol may appear on either side
        (``dep < 300`` and ``Gt(300, dep)`` both yield 300.0).

        Args:
            condition: A SymPy relational expression.
            symbol: The independent variable symbol.

        Returns:
            Boundary value as a Python float.

        Raises:
            ValueError: If the boundary cannot be extracted.
        """
        try:
            if hasattr(condition, "rhs"):
                lhs = getattr(condition, "lhs", None)
                if condition.rhs == symbol and lhs is not None and lhs != symbol:
                    return float(lhs)
                return float(condition.rhs)
            if hasattr(condition, "args") and len(condition.args) == 2:
                lhs, rhs = condition.args
                if lhs == symbol:
                    return float(rhs)
                if rhs == symbol:
                    return float(lhs)
            raise ValueError(f"Cannot extract boundary from condition: {condition}")
        except (ValueError, TypeError) as e:
            raise ValueError(f"Error extracting boundary from {condition}: {str(e)}") from e

    def _invert_linear_expression(self, expr: sp.Expr, input_symbol: sp.Symbol,
                                  output_symbol: sp.Symbol,
                                  boundary_val: Optional[float] = None) -> sp.Expr:
        """Inverts a linear expression: a*x + b = y => x = (y - b) / a.

        For constant pieces (degree 0, or the zero expression, for which
        ``sp.degree`` returns -oo), returns boundary_val as the inverse, or
        the constant itself when no finite boundary is available.

        Args:
            expr: Linear expression to invert.
            input_symbol: Independent variable (x).
            output_symbol: Output variable (y).
            boundary_val: Boundary of the piece (used for constant pieces).

        Returns:
            Inverted SymPy expression.

        Raises:
            ValueError: If the linear coefficient is near-zero or degree > 1.
        """
        try:
            deg = sp.degree(expr, input_symbol)
            if deg <= 0:
                if boundary_val is not None and boundary_val != float("inf"):
                    return sp.sympify(boundary_val)
                return sp.sympify(float(expr))
            if deg == 1:
                coeffs = sp.Poly(expr, input_symbol).all_coeffs()
                a, b = float(coeffs[0]), float(coeffs[1])
                if abs(a) < self.tolerance:
                    raise ValueError(f"Linear coefficient {a:.2e} is too small for stable inversion "
                                     f"(tolerance={self.tolerance:.2e})")
                return (output_symbol - b) / a
            raise ValueError(
                f"Expression has degree {deg}; only linear expressions are supported")
        except Exception as e:
            raise ValueError(f"Failed to invert expression {expr}: {str(e)}") from e