File size: 3,860 Bytes
705389b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b915149
705389b
 
 
 
 
b915149
705389b
 
 
 
 
b915149
705389b
 
 
 
 
b915149
705389b
 
 
 
 
b915149
705389b
 
 
 
 
b915149
705389b
 
 
b915149
 
 
 
 
 
 
705389b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from dataclasses import dataclass
from typing import Dict
from manager.utils.runtime_selector import detect_runtime_environment

@dataclass()
class ModelInfo:
    """Static description of one model plus its estimated costs.

    Fields are positional in declaration order. The two cost fields default
    to 0; ModelRegistry supplies real values when it builds its catalog.
    """

    name: str  # catalog key, e.g. "mistral"
    size: float  # relative size figure; for the local entries it matches the "NB" in their descriptions
    tokens_sec: int  # tokens-per-second figure used by the cost estimators
    type: str  # "local" or "api" in the built-in catalog
    description: str  # short human-readable summary
    create_cost: int = 0  # estimated cost to create the model (0 unless set)
    invoke_cost: int = 0  # estimated cost per invocation (0 unless set)

class ModelRegistry:
    """Catalog of known models annotated with estimated create/invoke costs.

    The catalog is built once, at construction time. ``env`` stores whatever
    ``detect_runtime_environment()`` returned; :meth:`get_filtered_models`
    uses it to choose between local and API models.
    """

    # Flat costs given to every API-backed model. The estimators below are
    # applied only to local models. These are class attributes so a subclass
    # (or a test) can tune them without touching the catalog code.
    API_CREATE_COST: int = 20
    API_INVOKE_COST: int = 50

    # Runtime labels for which local models are offered; any other label
    # falls back to API models. A tuple keeps the original ==-based
    # membership semantics.
    LOCAL_RUNTIMES = ("gpu", "cpu-local")

    def __init__(self) -> None:
        self.env = detect_runtime_environment()
        self.models = self._build_model_registry()

    def estimate_create_cost(self, size: float, is_api: bool) -> int:
        """Estimate the one-off cost of creating a model.

        Args:
            size: Relative size figure from the catalog.
            is_api: Whether the model is API-backed.

        Returns:
            ``int(size * 10)`` for API models, ``int(size * 5)`` otherwise.
            ``int`` truncates toward zero.
        """
        multiplier = 10 if is_api else 5
        return int(size * multiplier)

    def estimate_invoke_cost(self, tokens_sec: int, is_api: bool) -> int:
        """Estimate the per-call cost of invoking a model.

        The cost is a base amount (40 for API models, 20 otherwise) plus one
        unit for every token/second below 60. Models at or above 60 pay
        only the base amount.
        """
        base_cost = 40 if is_api else 20
        return base_cost + max(0, 60 - tokens_sec)

    def _build_model_registry(self) -> Dict[str, ModelInfo]:
        """Build the ``name -> ModelInfo`` mapping from the static catalog."""
        raw_models = {
            "llama3.2": {
                "size": 3,
                "tokens_sec": 30,
                "type": "local",
                "description": "3B lightweight local model"
            },
            "mistral": {
                "size": 7,
                "tokens_sec": 50,
                "type": "local",
                "description": "7B stronger local model"
            },
            "gemini-2.0-flash": {
                "size": 6,
                "tokens_sec": 170,
                "type": "api",
                "description": "Fast and efficient API model"
            },
            "gemini-2.5-pro-preview-03-25": {
                "size": 10,
                "tokens_sec": 148,
                "type": "api",
                "description": "High-reasoning API model"
            },
            "gemini-1.5-flash": {
                "size": 7,
                "tokens_sec": 190,
                "type": "api",
                "description": "Fast general-purpose model"
            },
            "gemini-2.0-flash-lite": {
                "size": 5,
                "tokens_sec": 208,
                "type": "api",
                "description": "Low-latency, cost-efficient API model"
            },
            "gemini-2.0-flash-live-001": {
                "size": 9,
                "tokens_sec": 190,
                "type": "api",
                "description": "Voice/video low-latency API model"
            }
        }

        models: Dict[str, ModelInfo] = {}
        for name, spec in raw_models.items():
            if spec["type"] == "api":
                # API models are priced flat rather than via the estimators.
                create_cost = self.API_CREATE_COST
                invoke_cost = self.API_INVOKE_COST
            else:
                create_cost = self.estimate_create_cost(spec["size"], is_api=False)
                invoke_cost = self.estimate_invoke_cost(spec["tokens_sec"], is_api=False)

            models[name] = ModelInfo(
                name=name,
                size=spec["size"],
                tokens_sec=spec["tokens_sec"],
                type=spec["type"],
                description=spec["description"],
                create_cost=create_cost,
                invoke_cost=invoke_cost
            )
        return models

    def get_filtered_models(self) -> Dict[str, ModelInfo]:
        """Return only models that match the current runtime.

        Local models are returned when ``env`` is one of ``LOCAL_RUNTIMES``;
        otherwise API models are returned. The result is a new dict.
        """
        wanted_type = "local" if self.env in self.LOCAL_RUNTIMES else "api"
        return {k: v for k, v in self.models.items() if v.type == wanted_type}

    def get_all_models(self) -> Dict[str, ModelInfo]:
        """Return all models regardless of runtime.

        A shallow copy is returned, consistent with get_filtered_models, so
        that adding or removing keys in the result cannot alter the registry.
        The ModelInfo values themselves are shared, not copied.
        """
        return dict(self.models)


if __name__ == "__main__":
    # Smoke run: report the detected runtime, then list the models offered
    # for it along with their estimated costs.
    reg = ModelRegistry()
    print(f"[INFO] Detected runtime: {reg.env}\n")

    print("Filtered models based on environment:")
    available = reg.get_filtered_models()
    for model_name, info in available.items():
        line = (
            f"{model_name}: create={info.create_cost}, "
            f"invoke={info.invoke_cost}, type={info.type}"
        )
        print(line)