# Source code for swarmauri.core.ComponentBase

from typing import (
    Optional, 
    List,
    Literal, 
    TypeVar, 
    Type, 
    Union, 
    Annotated, 
    Generic, 
    ClassVar, 
    Set,
    get_args)

from uuid import uuid4
from enum import Enum
import inspect
import hashlib
from pydantic import BaseModel, Field, field_validator
import logging
from swarmauri.core.typing import SubclassUnion

class ResourceTypes(Enum):
    """Names of the resource kinds known to the component framework.

    Each member's value is the string name of the resource.
    ``UNIVERSAL_BASE`` carries the value ``'ComponentBase'``.
    """

    UNIVERSAL_BASE = 'ComponentBase'
    AGENT = 'Agent'
    AGENT_FACTORY = 'AgentFactory'
    CHAIN = 'Chain'
    CHAIN_METHOD = 'ChainMethod'
    CHUNKER = 'Chunker'
    CONVERSATION = 'Conversation'
    DISTANCE = 'Distance'
    DOCUMENT_STORE = 'DocumentStore'
    DOCUMENT = 'Document'
    EMBEDDING = 'Embedding'
    EXCEPTION = 'Exception'
    LLM = 'LLM'
    MESSAGE = 'Message'
    METRIC = 'Metric'
    PARSER = 'Parser'
    PROMPT = 'Prompt'
    STATE = 'State'
    CHAINSTEP = 'ChainStep'
    SCHEMA_CONVERTER = 'SchemaConverter'
    SWARM = 'Swarm'
    TOOLKIT = 'Toolkit'
    TOOL = 'Tool'
    PARAMETER = 'Parameter'
    TRACE = 'Trace'
    UTIL = 'Util'
    VECTOR_STORE = 'VectorStore'
    VECTOR = 'Vector'
def generate_id() -> str:
    """Return a new random UUID (version 4) rendered as a string."""
    return str(uuid4())
class ComponentBase(BaseModel):
    """Pydantic base model shared by framework components.

    Every subclass is passed to ``__swm_register_subclass__`` when it is
    defined, and is stored in ``__swm_subclasses__`` if it annotates a
    ``type`` field not already registered.
    """

    name: Optional[str] = None
    # Defaults to a fresh UUID4 string for every instance.
    id: str = Field(default_factory=generate_id)
    members: List[str] = Field(default_factory=list)
    owner: Optional[str] = None
    # A truthy host marks the component as remote (see ``swm_is_remote``).
    host: Optional[str] = None
    resource: str = Field(default="ComponentBase")
    version: str = "0.1.0"
    # Registry shared by the whole hierarchy; filled by __init_subclass__.
    __swm_subclasses__: ClassVar[Set[Type['ComponentBase']]] = set()
    type: Literal['ComponentBase'] = 'ComponentBase'

    def __init_subclass__(cls, **kwargs):
        """Register every newly defined subclass with ``ComponentBase``."""
        super().__init_subclass__(**kwargs)
        ComponentBase.__swm_register_subclass__(cls)

    # @classmethod
    # def __swm__get_subclasses__(cls) -> set:
    #     logging.debug('__swm__get_subclasses__ executed\n')
    #     def is_excluded_module(module_name: str) -> bool:
    #         return (module_name == 'builtins' or
    #                 module_name == 'types')
    #     subclasses_dict = {cls.__name__: cls}
    #     for subclass in cls.__subclasses__():
    #         if not is_excluded_module(subclass.__module__):
    #             subclasses_dict.update({_s.__name__: _s for _s in subclass.__swm__get_subclasses__()
    #                                     if not is_excluded_module(_s.__module__)})
    #     return set(subclasses_dict.values())

    @classmethod
    def __swm_register_subclass__(cls, subclass):
        """Add ``subclass`` to the registry if its ``type`` annotation is new.

        A subclass without its own ``type`` annotation is not registered and
        a warning is logged instead.
        """
        logging.debug('__swm_register_subclass__ executed\n')

        if 'type' in subclass.__annotations__:
            sub_type = subclass.__annotations__['type']
            known_types = [
                registered.__annotations__['type']
                for registered in cls.__swm_subclasses__
            ]
            # Only the first class seen for a given ``type`` annotation is kept.
            if sub_type not in known_types:
                cls.__swm_subclasses__.add(subclass)
        else:
            logging.warning(f'Subclass {subclass.__name__} does not have a type annotation')

        # [subclass.__swm_reset_class__() for subclass in cls.__swm_subclasses__
        #  if hasattr(subclass, '__swm_reset_class__')]

    @classmethod
    def __swm_reset_class__(cls):
        """Refresh discriminated field annotations and rebuild the model.

        For every field that has a discriminator and is annotated directly on
        this class, each member of the annotation whose immediate base class
        has a ``__swm_subclasses__`` attribute (and is not ``ComponentBase``
        itself) causes the field annotation to be replaced by
        ``SubclassUnion[base]``. The model is then force-rebuilt.
        """
        logging.debug('__swm_reset_class__ executed\n')

        for each, field in cls.model_fields.items():
            # Lazy %-style arguments; the field name is data, not a format string.
            logging.debug('%s %s', each, field.discriminator)
            if not (field.discriminator and each in cls.__annotations__):
                continue

            for index in range(len(get_args(field.annotation))):
                # The annotation is re-read on every pass, as the original loop
                # did, because an earlier pass may already have replaced it.
                candidate = get_args(field.annotation)[index]
                if not hasattr(candidate, '__base__'):
                    continue
                baseclass = candidate.__base__
                if (hasattr(baseclass, '__swm_subclasses__')
                        and baseclass.__name__ != 'ComponentBase'):
                    sc = SubclassUnion[baseclass]
                    cls.__annotations__[each] = sc
                    field.annotation = sc

        # This is not necessary as the model_rebuild address forward_refs
        # https://docs.pydantic.dev/latest/api/base_model/#pydantic.BaseModel.model_post_init
        # cls.update_forward_refs()
        cls.model_rebuild(force=True)

    @field_validator('type')
    @classmethod
    def set_type(cls, v, values):
        """Swap the base value ``'ComponentBase'`` for the subclass name.

        Any other value, or validation on ``ComponentBase`` itself, returns
        ``v`` unchanged. ``values`` is accepted but not used.
        """
        if v == 'ComponentBase' and cls.__name__ != 'ComponentBase':
            return cls.__name__
        return v

    def __swm_class_hash__(self) -> str:
        """Return a SHA-256 hex digest of the public callables' signatures.

        Attributes are visited in ``dir(self)`` order; only callables whose
        name does not start with an underscore contribute.
        """
        sig_hash = hashlib.sha256()
        for attr_name in dir(self):
            # Check the name first so private attributes are never accessed.
            if attr_name.startswith("_"):
                continue
            attr_value = getattr(self, attr_name)
            if callable(attr_value):
                sig = inspect.signature(attr_value)
                sig_hash.update(str(sig).encode())
        return sig_hash.hexdigest()

    @classmethod
    def swm_public_interfaces(cls) -> List[str]:
        """List public callables and all properties found on the class.

        Callables are included only if their name does not start with an
        underscore; properties are included regardless of their name.
        """
        methods = []
        for attr_name in dir(cls):
            attr_value = getattr(cls, attr_name)
            is_public_callable = callable(attr_value) and not attr_name.startswith("_")
            if is_public_callable or isinstance(attr_value, property):
                methods.append(attr_name)
        return methods

    @classmethod
    def swm_ismethod_registered(cls, method_name: str) -> bool:
        """Return True if ``method_name`` is in ``swm_public_interfaces()``."""
        return method_name in cls.swm_public_interfaces()

    @classmethod
    def swm_method_signature(cls, input_signature) -> bool:
        """Return True if any public callable's signature string matches.

        ``input_signature`` is compared against ``str(inspect.signature(...))``
        for each callable returned by ``swm_public_interfaces()``.
        """
        for method_name in cls.swm_public_interfaces():
            method = getattr(cls, method_name)
            if callable(method) and str(inspect.signature(method)) == input_signature:
                return True
        return False

    @property
    def swm_path(self) -> str:
        """Path string built from host, owner, resource, name and id.

        - host and owner set: ``{host}/{owner}/{resource}/{name}/{id}``
        - otherwise, resource and name set: ``/{resource}/{name}/{id}``
        - otherwise: ``/{resource}/{id}``
        """
        if self.host and self.owner:
            return f"{self.host}/{self.owner}/{self.resource}/{self.name}/{self.id}"
        if self.resource and self.name:
            return f"/{self.resource}/{self.name}/{self.id}"
        return f"/{self.resource}/{self.id}"

    @property
    def swm_is_remote(self) -> bool:
        """True when ``host`` is set to a truthy value."""
        return bool(self.host)