from dataclasses import dataclass
from typing import Dict

from manager.utils.runtime_selector import detect_runtime_environment
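
# Note: detect_runtime_environment() comes from manager.utils.runtime_selector
# (not shown here); it is assumed to return a string such as "gpu",
# "cpu-local", or an API-only environment name, matching the checks in
# ModelRegistry.get_filtered_models() below.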


@dataclass
class ModelInfo:
    name: str
    size: float          # approximate size in billions of parameters
    tokens_sec: int      # rough throughput in tokens per second
    type: str            # "local" or "api"
    description: str
    create_cost: int = 0
    invoke_cost: int = 0


class ModelRegistry:
    def __init__(self):
        self.env = detect_runtime_environment()
        self.models = self._build_model_registry()

    def estimate_create_cost(self, size: float, is_api: bool) -> int:
        """Cost scales linearly with model size; API models cost twice as much per unit."""
        return int(size * (10 if is_api else 5))

    def estimate_invoke_cost(self, tokens_sec: int, is_api: bool) -> int:
        """Flat base cost plus a penalty for models slower than 60 tokens/sec."""
        base_cost = 40 if is_api else 20
        return base_cost + max(0, 60 - tokens_sec)
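
    # Worked example (illustrative, derived from the two formulas above):
    # a 7B local model running at 50 tokens/sec costs
    #   create = int(7 * 5)           = 35
    #   invoke = 20 + max(0, 60 - 50) = 30
    # which matches the values computed for "mistral" in the registry below.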

    def _build_model_registry(self) -> Dict[str, ModelInfo]:
        raw_models = {
            "llama3.2": {
                "size": 3,
                "tokens_sec": 30,
                "type": "local",
                "description": "3B lightweight local model"
            },
            "mistral": {
                "size": 7,
                "tokens_sec": 50,
                "type": "local",
                "description": "7B stronger local model"
            },
            "gemini-2.0-flash": {
                "size": 6,
                "tokens_sec": 170,
                "type": "api",
                "description": "Fast and efficient API model"
            },
            "gemini-2.5-pro-preview-03-25": {
                "size": 10,
                "tokens_sec": 148,
                "type": "api",
                "description": "High-reasoning API model"
            },
            "gemini-1.5-flash": {
                "size": 7,
                "tokens_sec": 190,
                "type": "api",
                "description": "Fast general-purpose model"
            },
            "gemini-2.0-flash-lite": {
                "size": 5,
                "tokens_sec": 208,
                "type": "api",
                "description": "Low-latency, cost-efficient API model"
            },
            "gemini-2.0-flash-live-001": {
                "size": 9,
                "tokens_sec": 190,
                "type": "api",
                "description": "Voice/video low-latency API model"
            }
        }

        models = {}
        for name, model in raw_models.items():
            is_api = model["type"] == "api"
            if is_api:
                # Flat cost for all API models
                create_cost, invoke_cost = 20, 50
            else:
                create_cost = self.estimate_create_cost(model["size"], is_api=False)
                invoke_cost = self.estimate_invoke_cost(model["tokens_sec"], is_api=False)
            models[name] = ModelInfo(
                name=name,
                size=model["size"],
                tokens_sec=model["tokens_sec"],
                type=model["type"],
                description=model["description"],
                create_cost=create_cost,
                invoke_cost=invoke_cost
            )
        return models

    def get_filtered_models(self) -> Dict[str, ModelInfo]:
        """Return only models that match the current runtime."""
        if self.env in ["gpu", "cpu-local"]:
            return {k: v for k, v in self.models.items() if v.type == "local"}
        else:
            return {k: v for k, v in self.models.items() if v.type == "api"}

    def get_all_models(self) -> Dict[str, ModelInfo]:
        """Return all models regardless of runtime."""
        return self.models
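
    # Illustrative helper, not part of the original file: a single-model
    # lookup by name. The method name and the KeyError-on-unknown-name
    # behavior are assumptions, shown here to demonstrate direct access
    # to the underlying Dict[str, ModelInfo].
    def get_model(self, name: str) -> ModelInfo:
        """Return one registered model, raising KeyError for unknown names."""
        return self.models[name]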


if __name__ == "__main__":
    registry = ModelRegistry()
    print(f"[INFO] Detected runtime: {registry.env}\n")

    print("Filtered models based on environment:")
    for name, model in registry.get_filtered_models().items():
        print(f"{name}: create={model.create_cost}, invoke={model.invoke_cost}, type={model.type}")