ParthSadaria commited on
Commit
dbffb4e
·
verified ·
1 Parent(s): c303b69

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +5 -3
main.py CHANGED
@@ -501,6 +501,8 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
501
 
502
  env_vars = get_env_vars()
503
 
 
 
504
  if model_to_use in mistral_models:
505
  endpoint = env_vars['mistral_api']
506
  custom_headers = {
@@ -524,6 +526,7 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
524
  custom_headers = {
525
  "Authorization": f"Bearer {env_vars['gemini_key']}"
526
  }
 
527
  else:
528
  endpoint = env_vars['secret_api_endpoint']
529
  custom_headers = {
@@ -532,12 +535,12 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
532
  "Referer": header_url
533
  }
534
 
535
- print(f"Using endpoint: {endpoint} for model: {model_to_use}")
536
 
537
  async def real_time_stream_generator():
538
  try:
539
  async with httpx.AsyncClient(timeout=60.0) as client:
540
- async with client.stream("POST", f"{endpoint}/v1/chat/completions", json=payload_dict, headers=custom_headers) as response:
541
  if response.status_code >= 400:
542
  error_messages = {
543
  422: "Unprocessable entity. Check your payload.",
@@ -576,7 +579,6 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
576
  async for chunk in real_time_stream_generator():
577
  response_content.append(chunk)
578
  return JSONResponse(content=json.loads(''.join(response_content)))
579
-
580
  @app.post("/images/generations")
581
  async def create_image(payload: ImageGenerationPayload, authenticated: bool = Depends(verify_api_key)):
582
  if not server_status:
 
501
 
502
  env_vars = get_env_vars()
503
 
504
+ target_url_path = "/v1/chat/completions" # Default path
505
+
506
  if model_to_use in mistral_models:
507
  endpoint = env_vars['mistral_api']
508
  custom_headers = {
 
526
  custom_headers = {
527
  "Authorization": f"Bearer {env_vars['gemini_key']}"
528
  }
529
+ target_url_path = "/chat/completions" # Use /chat/completions for Gemini
530
  else:
531
  endpoint = env_vars['secret_api_endpoint']
532
  custom_headers = {
 
535
  "Referer": header_url
536
  }
537
 
538
+ print(f"Using endpoint: {endpoint} with path: {target_url_path} for model: {model_to_use}")
539
 
540
  async def real_time_stream_generator():
541
  try:
542
  async with httpx.AsyncClient(timeout=60.0) as client:
543
+ async with client.stream("POST", f"{endpoint}{target_url_path}", json=payload_dict, headers=custom_headers) as response:
544
  if response.status_code >= 400:
545
  error_messages = {
546
  422: "Unprocessable entity. Check your payload.",
 
579
  async for chunk in real_time_stream_generator():
580
  response_content.append(chunk)
581
  return JSONResponse(content=json.loads(''.join(response_content)))
 
582
  @app.post("/images/generations")
583
  async def create_image(payload: ImageGenerationPayload, authenticated: bool = Depends(verify_api_key)):
584
  if not server_status: