ParthSadaria commited on
Commit
f99a3b8
·
verified ·
1 Parent(s): 04e0db8

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +10 -7
main.py CHANGED
@@ -496,7 +496,6 @@ async def search_gpt(q: str, request: Request, stream: bool = False, systempromp
496
  return JSONResponse(content=response_data)
497
 
498
  header_url = os.getenv('HEADER_URL')
499
-
500
  @app.post("/chat/completions")
501
  @app.post("/api/v1/chat/completions")
502
  async def get_completion(payload: Payload, request: Request, authenticated: bool = Depends(verify_api_key)):
@@ -516,16 +515,20 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
516
 
517
  usage_tracker.record_request(request=request, model=model_to_use, endpoint="/chat/completions")
518
 
519
- payload_dict = payload.dict()
520
- payload_dict["model"] = model_to_use
521
 
522
- # --- Start of the fix ---
 
523
  if payload.tools is not None:
524
  payload_dict["tools"] = payload.tools
 
 
525
  if payload.tool_choice is not None:
526
- payload_dict["tool_choice"] = payload.tool_choice
527
- # --- End of the fix ---
528
-
 
 
529
 
530
  stream_enabled = payload_dict.get("stream", True)
531
 
 
496
  return JSONResponse(content=response_data)
497
 
498
  header_url = os.getenv('HEADER_URL')
 
499
  @app.post("/chat/completions")
500
  @app.post("/api/v1/chat/completions")
501
  async def get_completion(payload: Payload, request: Request, authenticated: bool = Depends(verify_api_key)):
 
515
 
516
  usage_tracker.record_request(request=request, model=model_to_use, endpoint="/chat/completions")
517
 
518
+ payload_dict = payload.dict(exclude_none=True) # Exclude keys with None values
 
519
 
520
+ # The payload.dict(exclude_none=True) already handles this.
521
+ # The following checks are now redundant but can be kept for explicit clarity.
522
  if payload.tools is not None:
523
  payload_dict["tools"] = payload.tools
524
+
525
+ # Handle the tool_choice more robustly
526
  if payload.tool_choice is not None:
527
+ # Check if the value is valid before passing it on
528
+ if isinstance(payload.tool_choice, (str, dict)):
529
+ payload_dict["tool_choice"] = payload.tool_choice
530
+ else:
531
+ print(f"Warning: tool_choice received with invalid type: {type(payload.tool_choice)}. Skipping.")
532
 
533
  stream_enabled = payload_dict.get("stream", True)
534