Spaces:
Running
Running
File size: 1,719 Bytes
4f41410 2c50826 4f41410 2c50826 5291ba9 2c50826 4f41410 2c50826 5291ba9 2c50826 4f41410 2c50826 5291ba9 2c50826 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
from typing import Type
from api.baseline import BaselineAPI
from api.fireworks import FireworksAPI
from api.flux import FluxAPI
from api.pruna import PrunaAPI
from api.pruna_dev import PrunaDevAPI
from api.replicate import ReplicateAPI
from api.together import TogetherAPI
from api.fal import FalAPI
from api.aws import AWSBedrockAPI
# Public API of this module. Every backend class that create_api() can
# return is listed here so `from <module> import *` exposes all of them.
__all__ = [
    'create_api',
    'FluxAPI',
    'BaselineAPI',
    'FireworksAPI',
    'PrunaAPI',
    'PrunaDevAPI',
    'ReplicateAPI',
    'TogetherAPI',
    'FalAPI',
    'AWSBedrockAPI',
]
def create_api(api_type: str) -> FluxAPI:
    """
    Factory function to create API instances.

    Args:
        api_type (str): The type of API to create. Must be one of:
            - "baseline"
            - "fireworks"
            - "pruna_dev" (returns a PrunaDevAPI; checked before the
              generic "pruna_" prefix below)
            - "pruna_<speed_mode>" (everything after the "pruna_" prefix is
              passed to PrunaAPI as the speed mode; it must be non-empty)
            - "replicate"
            - "together"
            - "fal"
            - "aws"

    Returns:
        FluxAPI: An instance of the requested API implementation.

    Raises:
        ValueError: If an invalid API type is provided, including "pruna_"
            with no speed mode after the prefix.
    """
    pruna_prefix = "pruna_"

    # The exact match must come first: "pruna_dev" also starts with the
    # "pruna_" prefix and would otherwise be routed to PrunaAPI("dev").
    if api_type == "pruna_dev":
        return PrunaDevAPI()

    if api_type.startswith(pruna_prefix):
        # Derive the slice from the prefix itself instead of a hard-coded 6.
        speed_mode = api_type[len(pruna_prefix):]
        if not speed_mode:
            # A bare "pruna_" carries no speed mode; reject it rather than
            # constructing PrunaAPI("").
            raise ValueError(
                f"Invalid API type: {api_type}. "
                f"A speed mode must follow the '{pruna_prefix}' prefix"
            )
        return PrunaAPI(speed_mode)

    # Backends that are selected by an exact name and take no constructor
    # arguments.
    api_map: dict[str, Type[FluxAPI]] = {
        "baseline": BaselineAPI,
        "fireworks": FireworksAPI,
        "replicate": ReplicateAPI,
        "together": TogetherAPI,
        "fal": FalAPI,
        "aws": AWSBedrockAPI,
    }

    api_class = api_map.get(api_type)
    if api_class is None:
        raise ValueError(f"Invalid API type: {api_type}. Must be one of {list(api_map.keys())} or start with 'pruna_'")
    return api_class()
|