"""Factory and public re-exports for the API backend implementations."""

from typing import Type

from api.baseline import BaselineAPI
from api.fireworks import FireworksAPI
from api.flux import FluxAPI
from api.pruna import PrunaAPI
from api.pruna_dev import PrunaDevAPI
from api.replicate import ReplicateAPI
from api.together import TogetherAPI
from api.fal import FalAPI
from api.aws import AWSBedrockAPI

__all__ = [
    'create_api',
    'FluxAPI',
    'BaselineAPI',
    'FireworksAPI',
    'PrunaAPI',
    'ReplicateAPI',
    'TogetherAPI',
    'FalAPI',
    'PrunaDevAPI',
    'AWSBedrockAPI',  # returned by create_api("aws"); was missing from the public API
]


def create_api(api_type: str) -> FluxAPI:
    """
    Factory function to create API instances.

    Args:
        api_type (str): The type of API to create. Must be one of:
            - "baseline"
            - "fireworks"
            - "pruna_dev"
            - "pruna_<speed_mode>": everything after the "pruna_" prefix is
              passed to PrunaAPI as its speed mode
            - "replicate"
            - "together"
            - "fal"
            - "aws"

    Returns:
        FluxAPI: An instance of the requested API implementation

    Raises:
        ValueError: If an invalid API type is provided
    """
    # "pruna_dev" must be matched before the generic "pruna_" prefix below,
    # otherwise it would be routed to PrunaAPI with speed mode "dev".
    if api_type == "pruna_dev":
        return PrunaDevAPI()

    pruna_prefix = "pruna_"
    if api_type.startswith(pruna_prefix):
        speed_mode = api_type[len(pruna_prefix):]
        return PrunaAPI(speed_mode)

    api_map: dict[str, Type[FluxAPI]] = {
        "baseline": BaselineAPI,
        "fireworks": FireworksAPI,
        "replicate": ReplicateAPI,
        "together": TogetherAPI,
        "fal": FalAPI,
        "aws": AWSBedrockAPI,
    }

    if api_type not in api_map:
        raise ValueError(f"Invalid API type: {api_type}. Must be one of {list(api_map.keys())} or start with 'pruna_'")

    return api_map[api_type]()