saisha09 committed
Commit 705389b · 1 Parent(s): 344c9c4

models_cost file added

Files changed (1): src/models_cost.py +109 -0
src/models_cost.py ADDED
@@ -0,0 +1,109 @@
+ from dataclasses import dataclass
+ from typing import Dict
+ from manager.utils.runtime_selector import detect_runtime_environment
+
+
+ @dataclass
+ class ModelInfo:
+     name: str
+     size: float
+     tokens_sec: int
+     type: str
+     description: str
+     create_cost: int = 0
+     invoke_cost: int = 0
+
+
+ class ModelRegistry:
+     def __init__(self):
+         self.env = detect_runtime_environment()
+         self.models = self._build_model_registry()
+
+     def estimate_create_cost(self, size: float, is_api: bool) -> int:
+         # Creation cost scales linearly with model size; API models are
+         # weighted twice as heavily as local ones.
+         return int(size * (10 if is_api else 5))
+
+     def estimate_invoke_cost(self, tokens_sec: int, is_api: bool) -> int:
+         # Invocation cost = flat base fee plus a penalty for throughput
+         # below 60 tokens/sec; faster models are cheaper to call.
+         base_cost = 40 if is_api else 20
+         return base_cost + max(0, 60 - tokens_sec)
+
+     def _build_model_registry(self) -> Dict[str, ModelInfo]:
+         raw_models = {
+             "llama3.2": {
+                 "size": 3,
+                 "tokens_sec": 30,
+                 "type": "local",
+                 "description": "3B lightweight local model"
+             },
+             "mistral": {
+                 "size": 7,
+                 "tokens_sec": 50,
+                 "type": "local",
+                 "description": "7B stronger local model"
+             },
+             "gemini-2.0-flash": {
+                 "size": 6,
+                 "tokens_sec": 60,
+                 "type": "api",
+                 "description": "Fast and efficient API model"
+             },
+             "gemini-2.5-pro-preview-03-25": {
+                 "size": 10,
+                 "tokens_sec": 45,
+                 "type": "api",
+                 "description": "High-reasoning API model"
+             },
+             "gemini-1.5-flash": {
+                 "size": 7,
+                 "tokens_sec": 55,
+                 "type": "api",
+                 "description": "Fast general-purpose model"
+             },
+             "gemini-2.0-flash-lite": {
+                 "size": 5,
+                 "tokens_sec": 58,
+                 "type": "api",
+                 "description": "Low-latency, cost-efficient API model"
+             },
+             "gemini-2.0-flash-live-001": {
+                 "size": 9,
+                 "tokens_sec": 52,
+                 "type": "api",
+                 "description": "Voice/video low-latency API model"
+             }
+         }
+
+         # Compute both cost fields for each raw entry and wrap it in ModelInfo.
+         models = {}
+         for name, model in raw_models.items():
+             is_api = model["type"] == "api"
+             create_cost = self.estimate_create_cost(model["size"], is_api)
+             invoke_cost = self.estimate_invoke_cost(model["tokens_sec"], is_api)
+
+             models[name] = ModelInfo(
+                 name=name,
+                 size=model["size"],
+                 tokens_sec=model["tokens_sec"],
+                 type=model["type"],
+                 description=model["description"],
+                 create_cost=create_cost,
+                 invoke_cost=invoke_cost
+             )
+
+         return models
+
+     def get_filtered_models(self) -> Dict[str, ModelInfo]:
+         """Return only models that match the current runtime."""
+         if self.env in ["gpu", "cpu-local"]:
+             return {k: v for k, v in self.models.items() if v.type == "local"}
+         else:
+             return {k: v for k, v in self.models.items() if v.type == "api"}
+
+     def get_all_models(self) -> Dict[str, ModelInfo]:
+         """Return all models regardless of runtime."""
+         return self.models
+
+
+ if __name__ == "__main__":
+     registry = ModelRegistry()
+     print(f"[INFO] Detected runtime: {registry.env}\n")
+
+     print("Filtered models based on environment:")
+     for name, model in registry.get_filtered_models().items():
+         print(f"{name}: create={model.create_cost}, invoke={model.invoke_cost}, type={model.type}")