File size: 1,719 Bytes
4f41410
2c50826
 
 
 
 
4f41410
2c50826
 
 
5291ba9
2c50826
4f41410
 
 
 
 
 
 
 
 
 
 
 
2c50826
 
 
 
 
 
 
 
 
 
 
 
5291ba9
2c50826
 
 
 
 
 
 
4f41410
 
2c50826
 
 
 
 
 
 
 
 
 
5291ba9
2c50826
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from typing import Type

from api.baseline import BaselineAPI
from api.fireworks import FireworksAPI
from api.flux import FluxAPI
from api.pruna import PrunaAPI
from api.pruna_dev import PrunaDevAPI
from api.replicate import ReplicateAPI
from api.together import TogetherAPI
from api.fal import FalAPI
from api.aws import AWSBedrockAPI

# Public names of this module: the create_api factory, the FluxAPI base type,
# and every implementation class that create_api can return.
__all__ = [
    'create_api',
    'FluxAPI',
    'BaselineAPI',
    'FireworksAPI',
    'PrunaAPI',
    'ReplicateAPI',
    'TogetherAPI',
    'FalAPI',
    'PrunaDevAPI',
    'AWSBedrockAPI',
]

def create_api(api_type: str) -> FluxAPI:
    """
    Factory function to create API instances.

    Args:
        api_type (str): The type of API to create. Must be one of:
            - "baseline"
            - "fireworks"
            - "pruna_dev"
            - "pruna_<speed_mode>" (where <speed_mode> is the desired,
              non-empty speed mode; it is passed to PrunaAPI)
            - "replicate"
            - "together"
            - "fal"
            - "aws"

    Returns:
        FluxAPI: An instance of the requested API implementation

    Raises:
        ValueError: If an invalid API type is provided, including a bare
            "pruna_" with no speed mode after the prefix.
    """
    # "pruna_dev" must be checked before the generic prefix below, otherwise it
    # would be routed to PrunaAPI with speed_mode="dev".
    if api_type == "pruna_dev":
        return PrunaDevAPI()

    pruna_prefix = "pruna_"
    if api_type.startswith(pruna_prefix):
        # Slice by the prefix length rather than a hard-coded index so the two
        # cannot drift apart.
        speed_mode = api_type[len(pruna_prefix):]
        if not speed_mode:
            raise ValueError(
                f"Invalid API type: {api_type}. "
                f"A speed mode must follow the '{pruna_prefix}' prefix"
            )
        return PrunaAPI(speed_mode)

    api_map: dict[str, Type[FluxAPI]] = {
        "baseline": BaselineAPI,
        "fireworks": FireworksAPI,
        "replicate": ReplicateAPI,
        "together": TogetherAPI,
        "fal": FalAPI,
        "aws": AWSBedrockAPI,
    }

    api_class = api_map.get(api_type)
    if api_class is None:
        raise ValueError(f"Invalid API type: {api_type}. Must be one of {list(api_map.keys())} or start with 'pruna_'")

    return api_class()