ParthSadaria commited on
Commit
72b5133
·
verified ·
1 Parent(s): 81400df

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +20 -6
main.py CHANGED
@@ -12,7 +12,7 @@ import json
12
  from typing import Optional
13
  import datetime
14
 
15
- load_dotenv()
16
 
17
  app = FastAPI()
18
 
@@ -31,6 +31,8 @@ if not secret_api_endpoint or not secret_api_endpoint_2 or not secret_api_endpoi
31
  # Define models that should use the secondary endpoint
32
  alternate_models = {"gpt-4o-mini", "claude-3-haiku", "llama-3.1-70b", "mixtral-8x7b"}
33
 
 
 
34
  class Payload(BaseModel):
35
  model: str
36
  messages: list
@@ -141,11 +143,20 @@ async def get_models():
141
  async def fetch_models():
142
  return await get_models()
143
 
144
- available_model_ids = [
145
- "gpt-4o", "gpt-4o-mini", "claude-3-haiku", "llama-3.1-405b", "llama-3.1-70b",
146
- "llama-3.1-8b", "gemini-1.5-flash", "mixtral-8x7b" , "command-r","gemini-pro",
147
- "gpt-3.5-turbo", "command","claude-sonnet-3.5"
148
- ]
 
 
 
 
 
 
 
 
 
149
  @app.post("/chat/completions")
150
  @app.post("/v1/chat/completions")
151
  async def get_completion(payload: Payload,request: Request):
@@ -304,6 +315,9 @@ async def playground():
304
  return HTMLResponse(content="<h1>playground.html not found</h1>", status_code=404)
305
  @app.on_event("startup")
306
  async def startup_event():
 
 
 
307
  print("API endpoints:")
308
  print("GET /")
309
  print("GET /models")
 
12
  from typing import Optional
13
  import datetime
14
 
15
+ load_dotenv() #idk why this shi
16
 
17
  app = FastAPI()
18
 
 
31
  # Define models that should use the secondary endpoint
32
  alternate_models = {"gpt-4o-mini", "claude-3-haiku", "llama-3.1-70b", "mixtral-8x7b"}
33
 
34
+ available_model_ids = []
35
+
36
  class Payload(BaseModel):
37
  model: str
38
  messages: list
 
143
  async def fetch_models():
144
  return await get_models()
145
 
146
+ def load_model_ids(json_file_path):
147
+ try:
148
+ with open(json_file_path, 'r') as f:
149
+ models_data = json.load(f)
150
+ # Extract 'id' from each model object
151
+ model_ids = [model['id'] for model in models_data if 'id' in model]
152
+ return model_ids
153
+ except FileNotFoundError:
154
+ print("Error: models.json file not found.")
155
+ return []
156
+ except json.JSONDecodeError:
157
+ print("Error: Invalid JSON format in models.json.")
158
+ return []
159
+
160
  @app.post("/chat/completions")
161
  @app.post("/v1/chat/completions")
162
  async def get_completion(payload: Payload,request: Request):
 
315
  return HTMLResponse(content="<h1>playground.html not found</h1>", status_code=404)
316
  @app.on_event("startup")
317
  async def startup_event():
318
+ global available_model_ids
319
+ available_model_ids = load_model_ids("models.json")
320
+ print(f"Loaded model IDs: {available_model_ids}")
321
  print("API endpoints:")
322
  print("GET /")
323
  print("GET /models")