from dataclasses import dataclass
from typing import Dict

from manager.utils.runtime_selector import detect_runtime_environment


@dataclass
class ModelInfo:
    """Static metadata plus heuristic cost estimates for a single model."""

    name: str
    size: float  # rough size score (parameters in billions for local models)
    tokens_sec: int  # approximate generation throughput, tokens per second
    type: str  # "local" or "api"
    description: str
    create_cost: int = 0
    invoke_cost: int = 0


class ModelRegistry:
    def __init__(self):
        self.env = detect_runtime_environment()
        self.models = self._build_model_registry()

    def estimate_create_cost(self, size: float, is_api: bool) -> int:
        """Heuristic one-time setup cost: size scaled by 10 for API models, 5 for local."""
        return int(size * (10 if is_api else 5))
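
    # Worked example (illustrative): estimate_create_cost(7, is_api=False)
    # -> int(7 * 5) = 35; with is_api=True the same size would cost 70.
    # Note that _build_model_registry below only takes the is_api=False
    # path; API models receive flat costs instead.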

    def estimate_invoke_cost(self, tokens_sec: int, is_api: bool) -> int:
        """Heuristic per-call cost: a flat base (40 API / 20 local) plus a penalty for throughput below 60 tokens/sec."""
        base_cost = 40 if is_api else 20
        return base_cost + max(0, 60 - tokens_sec)
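
    # Worked example (illustrative): a local model at 50 tokens/sec costs
    # 20 + max(0, 60 - 50) = 30 per call; an API model at 170 tokens/sec
    # costs 40 + max(0, 60 - 170) = 40, since the penalty floors at zero.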

    def _build_model_registry(self) -> Dict[str, ModelInfo]:
        raw_models = {
            "llama3.2": {
                "size": 3,
                "tokens_sec": 30,
                "type": "local",
                "description": "3B lightweight local model"
            },
            "mistral": {
                "size": 7,
                "tokens_sec": 50,
                "type": "local",
                "description": "7B stronger local model"
            },
            "gemini-2.0-flash": {
                "size": 6,
                "tokens_sec": 170,
                "type": "api",
                "description": "Fast and efficient API model"
            },
            "gemini-2.5-pro-preview-03-25": {
                "size": 10,
                "tokens_sec": 148,
                "type": "api",
                "description": "High-reasoning API model"
            },
            "gemini-1.5-flash": {
                "size": 7,
                "tokens_sec": 190,
                "type": "api",
                "description": "Fast general-purpose model"
            },
            "gemini-2.0-flash-lite": {
                "size": 5,
                "tokens_sec": 208,
                "type": "api",
                "description": "Low-latency, cost-efficient API model"
            },
            "gemini-2.0-flash-live-001": {
                "size": 9,
                "tokens_sec": 190,
                "type": "api",
                "description": "Voice/video low-latency API model"
            }
        }

        models = {}
        for name, model in raw_models.items():
            is_api = model["type"] == "api"

            if is_api:
                # API models get flat costs rather than the size/throughput
                # estimators above.
                create_cost, invoke_cost = 20, 50
            else:
                create_cost = self.estimate_create_cost(model["size"], is_api=False)
                invoke_cost = self.estimate_invoke_cost(model["tokens_sec"], is_api=False)

            models[name] = ModelInfo(
                name=name,
                size=model["size"],
                tokens_sec=model["tokens_sec"],
                type=model["type"],
                description=model["description"],
                create_cost=create_cost,
                invoke_cost=invoke_cost
            )
        return models

    def get_filtered_models(self) -> Dict[str, ModelInfo]:
        """Return only models that match the current runtime."""
        if self.env in ["gpu", "cpu-local"]:
            return {k: v for k, v in self.models.items() if v.type == "local"}
        else:
            return {k: v for k, v in self.models.items() if v.type == "api"}

    def get_all_models(self) -> Dict[str, ModelInfo]:
        """Return all models regardless of runtime."""
        return self.models
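
    # A convenience lookup along these lines could sit beside the accessors
    # above; `cheapest_invoke` is a hypothetical helper, not part of the
    # original API: a minimal sketch that assumes the filtered set is
    # non-empty and breaks ties arbitrarily.
    def cheapest_invoke(self) -> ModelInfo:
        """Pick the runtime-appropriate model with the lowest invoke cost."""
        return min(self.get_filtered_models().values(), key=lambda m: m.invoke_cost)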


if __name__ == "__main__":
    registry = ModelRegistry()
    print(f"[INFO] Detected runtime: {registry.env}\n")

    print("Filtered models based on environment:")
    for name, model in registry.get_filtered_models().items():
        print(f"{name}: create={model.create_cost}, invoke={model.invoke_cost}, type={model.type}")
|