# hashiruAI/src/models_cost.py
from dataclasses import dataclass
from typing import Dict
from manager.utils.runtime_selector import detect_runtime_environment
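# NOTE: detect_runtime_environment() lives in manager/utils/runtime_selector
# (not shown here); judging from get_filtered_models() below, it is expected
# to return "gpu" or "cpu-local" on machines that can host local models,
# and some other label for API-only runtimes.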


@dataclass
class ModelInfo:
    """Static metadata and estimated costs for a single model."""
    name: str
    size: float          # model size (billions of parameters for the local models)
    tokens_sec: int      # approximate generation throughput (tokens/second)
    type: str            # "local" or "api"
    description: str
    create_cost: int = 0
    invoke_cost: int = 0


class ModelRegistry:
    """Builds the model registry and annotates each entry with estimated costs."""

    def __init__(self):
        self.env = detect_runtime_environment()
        self.models = self._build_model_registry()

    def estimate_create_cost(self, size: float, is_api: bool) -> int:
        """Creation cost scales with model size: a 10x multiplier for API models, 5x for local."""
        return int(size * (10 if is_api else 5))

    def estimate_invoke_cost(self, tokens_sec: int, is_api: bool) -> int:
        """Flat base cost (40 for API, 20 for local) plus a penalty for throughput below 60 tokens/sec."""
        base_cost = 40 if is_api else 20
        return base_cost + max(0, 60 - tokens_sec)
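
    # Worked example for a local model ("mistral": size=7, tokens_sec=50):
    #   create_cost = int(7 * 5)           -> 35
    #   invoke_cost = 20 + max(0, 60 - 50) -> 30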

    def _build_model_registry(self) -> Dict[str, ModelInfo]:
        """Attach estimated create/invoke costs to each raw model entry."""
        raw_models = {
"llama3.2": {
"size": 3,
"tokens_sec": 30,
"type": "local",
"description": "3B lightweight local model"
},
"mistral": {
"size": 7,
"tokens_sec": 50,
"type": "local",
"description": "7B stronger local model"
},
"gemini-2.0-flash": {
"size": 6,
"tokens_sec": 170,
"type": "api",
"description": "Fast and efficient API model"
},
"gemini-2.5-pro-preview-03-25": {
"size": 10,
"tokens_sec": 148,
"type": "api",
"description": "High-reasoning API model"
},
"gemini-1.5-flash": {
"size": 7,
"tokens_sec": 190,
"type": "api",
"description": "Fast general-purpose model"
},
"gemini-2.0-flash-lite": {
"size": 5,
"tokens_sec": 208,
"type": "api",
"description": "Low-latency, cost-efficient API model"
},
"gemini-2.0-flash-live-001": {
"size": 9,
"tokens_sec": 190,
"type": "api",
"description": "Voice/video low-latency API model"
}
}
models = {}
for name, model in raw_models.items():
is_api = model["type"] == "api"
if is_api:
# Flat cost for all API models
create_cost, invoke_cost = 20, 50
else:
create_cost = self.estimate_create_cost(model["size"], is_api=False)
invoke_cost = self.estimate_invoke_cost(model["tokens_sec"], is_api=False)
models[name] = ModelInfo(
name=name,
size=model["size"],
tokens_sec=model["tokens_sec"],
type=model["type"],
description=model["description"],
create_cost=create_cost,
invoke_cost=invoke_cost
)
return models
def get_filtered_models(self) -> Dict[str, ModelInfo]:
"""Return only models that match the current runtime."""
if self.env in ["gpu", "cpu-local"]:
return {k: v for k, v in self.models.items() if v.type == "local"}
else:
return {k: v for k, v in self.models.items() if v.type == "api"}
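
    # Example: on a "gpu" or "cpu-local" runtime, get_filtered_models() returns
    # only the local entries (llama3.2, mistral); other runtimes see only the
    # Gemini API models.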
def get_all_models(self) -> Dict[str, ModelInfo]:
"""Return all models regardless of runtime."""
return self.models
if __name__ == "__main__":
registry = ModelRegistry()
print(f"[INFO] Detected runtime: {registry.env}\n")
print("Filtered models based on environment:")
for name, model in registry.get_filtered_models().items():
print(f"{name}: create={model.create_cost}, invoke={model.invoke_cost}, type={model.type}")