saisha09 committed on
Commit
82a76a7
·
1 Parent(s): 69dc898

updated model selector

Browse files
Files changed (1) hide show
  1. src/config/model_selector.py +25 -16
src/config/model_selector.py CHANGED
@@ -1,21 +1,30 @@
1
  from src.utils.runtime_selector import detect_runtime_environment
 
2
  import os
3
  from dotenv import load_dotenv
4
  load_dotenv()
5
 
6
- #Placeholders
7
- MODEL_MAP = {
8
- "gpu": "gemini-2.5-pro-exp-03-25",
9
- "cpu-local": "gemini-2.0-flash",
10
- "cloud-only": "gemini-2.0-flash"
11
- }
12
- def choose_best_model():
13
- runtime_env = detect_runtime_environment()
14
- print(f"[DEBUG] Runtime env: {runtime_env}")
15
- print(f"[DEBUG] API key exists: {'Yes' if os.environ.get('GEMINI_KEY') else 'No'}")
16
- if runtime_env == "cpu-local":
17
- if os.environ.get("GEMINI_KEY"):
18
- return "gemini-2.0-flash"
19
- else:
20
- print("[WARN] No GEMINI_KEY set, falling back to llama3.2.")
21
- return "llama3.2"
 
 
 
 
 
 
 
 
 
1
  from src.utils.runtime_selector import detect_runtime_environment
2
+ from cost_benefit import get_best_model
3
  import os
4
  from dotenv import load_dotenv
5
  load_dotenv()
6
 
7
+ def choose_best_model(return_full=False):
8
+ env = detect_runtime_environment()
9
+ print(f"[INFO] Runtime Environment: {env}")
10
+
11
+ weights = {
12
+ "w_size": 0.1,
13
+ "w_token_cost": 100,
14
+ "w_speed": 0.5
15
+ }
16
+
17
+ result = get_best_model(weights, env)
18
+
19
+ if isinstance(result, str) or not result.get("model"):
20
+ if env == "cpu-local":
21
+ if os.getenv("GEMINI_KEY"):
22
+ print("[INFO] Falling back to Gemini for cpu-local.")
23
+ return {"model": "gemini-2.0-flash"} if return_full else "gemini-2.0-flash"
24
+ else:
25
+ print("[WARN] GOOGLE_API_KEY missing. Falling back to llama3.2.")
26
+ return {"model": "llama3.2"} if return_full else "llama3.2"
27
+ return {"model": "llama3.2"} if return_full else "llama3.2"
28
+
29
+ print(f"[INFO] Auto-selected model: {result['model']}")
30
+ return result if return_full else result["model"]