# Source code for swarmauri.experimental.chains.TypeAgnosticCallableChain

from typing import Any, Callable, List, Dict, Optional, Tuple, Union

# One chain step: (callable, positional args, keyword args, input handler).
# The input handler may be a built-in handler name, a custom callable, or None.
CallableDefinition = Tuple[
    Callable,
    List[Any],
    Dict[str, Any],
    Optional[Union[str, Callable]],
]

class TypeAgnosticCallableChain:
    """Run callables in sequence, optionally threading each result into the next.

    Every step is a ``(func, args, kwargs, input_handler)`` tuple. When the
    chain is called, each step's ``func`` is invoked with its stored ``args``
    and ``kwargs``. If the previous step returned a non-``None`` value, the
    chain first calls ``input_handler(previous_result, *args, **kwargs)``,
    which must return the ``(args, kwargs)`` actually passed to ``func``.

    ``input_handler`` may be ``None`` (ignore the previous result), a custom
    callable, or the name of a built-in handler without its leading
    underscore (e.g. ``"use_first_arg"``).

    Chains compose with ``|``: ``a | b`` is a new chain running ``a``'s steps
    followed by ``b``'s.
    """

    def __init__(self, callables: "Optional[List[CallableDefinition]]" = None):
        """Create a chain from an optional list of step tuples.

        The list is copied, so later ``add_callable`` calls never mutate a
        list owned by the caller.
        """
        self.callables = list(callables) if callables is not None else []

    # ----- built-in input handlers -----
    # The chain calls every handler as handler(previous_result, *step_args,
    # **step_kwargs) and uses the returned (args, kwargs) for the step.

    @staticmethod
    def _ignore_previous(_previous_result, *args, **kwargs):
        """Discard the previous result and use the step's own arguments."""
        return args, kwargs

    @staticmethod
    def _use_first_arg(previous_result, *args, **kwargs):
        """Pass the previous result as the first positional argument."""
        return [previous_result] + list(args), kwargs

    @staticmethod
    def _use_all_previous_args_first(previous_result, *args, **kwargs):
        """Spread a list/tuple previous result before the step's own args.

        A non-sequence previous result is treated as a single argument.
        """
        if not isinstance(previous_result, (list, tuple)):
            previous_result = [previous_result]
        return list(previous_result) + list(args), kwargs

    @staticmethod
    def _use_all_previous_args_only(previous_result, *_args, **_kwargs):
        """Use only the previous result as positional args; drop step args/kwargs."""
        if not isinstance(previous_result, (list, tuple)):
            previous_result = [previous_result]
        return list(previous_result), {}

    @staticmethod
    def _add_previous_kwargs_overwrite(previous_result, *args, **kwargs):
        """Merge a dict previous result into the step kwargs; previous values win.

        Raises:
            ValueError: if the previous result is not a dict.
        """
        # Takes *args/**kwargs like the other handlers: the chain unpacks the
        # step's arguments when calling a handler.
        if not isinstance(previous_result, dict):
            raise ValueError("Previous result is not a dictionary.")
        return args, {**kwargs, **previous_result}

    @staticmethod
    def _add_previous_kwargs_no_overwrite(previous_result, *args, **kwargs):
        """Merge a dict previous result into the step kwargs; step values win.

        Raises:
            ValueError: if the previous result is not a dict.
        """
        if not isinstance(previous_result, dict):
            raise ValueError("Previous result is not a dictionary.")
        return args, {**previous_result, **kwargs}

    @staticmethod
    def _use_all_args_all_kwargs_overwrite(previous_result_args, previous_result_kwargs, *args, **kwargs):
        """Concatenate previous/step args and merge kwargs; step kwargs win.

        NOTE(review): the chain calls handlers as handler(result, *args,
        **kwargs), so through ``__call__`` ``previous_result_kwargs`` receives
        the step's first positional arg. NOTE(review): step kwargs win here,
        which is the opposite precedence from ``_add_previous_kwargs_overwrite``.
        Confirm both are intended.
        """
        if isinstance(previous_result_args, (list, tuple)):
            combined_args = list(previous_result_args) + list(args)
        else:
            combined_args = list(args)
        # Copy so the caller's dict is not mutated by update() below.
        combined_kwargs = dict(previous_result_kwargs) if isinstance(previous_result_kwargs, dict) else {}
        combined_kwargs.update(kwargs)
        return combined_args, combined_kwargs

    @staticmethod
    def _use_all_args_all_kwargs_no_overwrite(previous_result_args, previous_result_kwargs, *args, **kwargs):
        """Concatenate previous/step args and merge kwargs; previous kwargs win.

        NOTE(review): same calling-convention and precedence caveats as
        ``_use_all_args_all_kwargs_overwrite``.
        """
        if isinstance(previous_result_args, (list, tuple)):
            combined_args = list(previous_result_args) + list(args)
        else:
            combined_args = list(args)
        previous_kwargs = previous_result_kwargs if isinstance(previous_result_kwargs, dict) else {}
        combined_kwargs = {**kwargs, **previous_kwargs}
        return combined_args, combined_kwargs

    # ----- chain management -----

    def _resolve_input_handler(self, input_handler: Union[str, Callable, None]) -> Callable:
        """Map a handler spec (None, handler name, or callable) to a callable.

        Raises:
            ValueError: if a string does not name a callable ``_<name>`` attribute.
        """
        if input_handler is None:
            return self._ignore_previous
        if isinstance(input_handler, str):
            handler = getattr(self, f"_{input_handler}", None)
            if handler is None or not callable(handler):
                raise ValueError(f"Unknown input handler name: {input_handler}")
            return handler
        return input_handler

    def add_callable(
        self,
        func: Callable,
        args: Optional[List[Any]] = None,
        kwargs: Optional[Dict[str, Any]] = None,
        input_handler: Union[str, Callable, None] = None,
    ) -> None:
        """Append a step to the chain.

        Args:
            func: the callable to run for this step.
            args: positional arguments for ``func`` (copied; defaults to none).
            kwargs: keyword arguments for ``func`` (copied; defaults to none).
            input_handler: ``None``, a built-in handler name, or a callable
                adapting the previous result into this step's arguments.

        Raises:
            ValueError: if ``input_handler`` is an unknown handler name.
        """
        handler = self._resolve_input_handler(input_handler)
        self.callables.append((func, list(args or []), dict(kwargs or {}), handler))

    def __call__(self, *initial_args, **initial_kwargs) -> Any:
        """Run every step in order and return the last step's result.

        Positional call arguments are prepended to the first step's args, and
        keyword call arguments override its stored kwargs. A step's input
        handler is skipped when the previous result is ``None``; the step then
        runs with its stored arguments. Returns ``None`` for an empty chain.
        """
        result = None
        for index, (func, args, kwargs, input_handler) in enumerate(self.callables):
            # Steps supplied via __init__ may still carry string/None handlers.
            handler = self._resolve_input_handler(input_handler)
            if index == 0:
                # Incoming values go first, mirroring how handlers place the
                # previous result ahead of a step's configured args.
                args = list(initial_args) + list(args)
                kwargs = {**kwargs, **initial_kwargs}
            if result is not None:
                args, kwargs = handler(result, *args, **kwargs)
            result = func(*args, **kwargs)
        return result

    def __or__(self, other: "TypeAgnosticCallableChain") -> "TypeAgnosticCallableChain":
        """Return a new chain running this chain's steps, then ``other``'s.

        Raises:
            TypeError: if ``other`` is not a ``TypeAgnosticCallableChain``.
        """
        if not isinstance(other, TypeAgnosticCallableChain):
            raise TypeError("Operand must be an instance of TypeAgnosticCallableChain")
        return TypeAgnosticCallableChain(self.callables + other.callables)