Update routes/chatCompletion.py
routes/chatCompletion.py  CHANGED  (+40 -40)
@@ -1,41 +1,41 @@
-from fastapi import APIRouter
-from fastapi.responses import StreamingResponse
-from models.chat_completion import ChatRequest
-from huggingface_hub import InferenceClient
-import json
-
-router = APIRouter()
-
-def generate_stream(response):
-    for chunk in response:
-        yield f"data: {json.dumps(chunk.__dict__, separators=(',', ':'))}\n\n"
-
-@router.post("/v1/chat
-async def chat_completion(body: ChatRequest):
-    client = InferenceClient(model=body.model)
-
-    res = client.chat_completion(
-        messages=body.messages,
-        frequency_penalty=body.frequency_penalty,
-        logit_bias=body.logit_bias,
-        logprobs=body.logprobs,
-        max_tokens=body.max_tokens,
-        n=body.n,
-        presence_penalty=body.presence_penalty,
-        response_format=body.response_format,
-        seed=body.seed,
-        stop=body.stop,
-        stream=body.stream,
-        stream_options=body.stream_options,
-        temperature=body.temperature,
-        top_logprobs=body.top_logprobs,
-        top_p=body.top_p,
-        tool_choice=body.tool_choice,
-        tool_prompt=body.tool_prompt,
-        tools=body.tools
-    )
-
-    if not body.stream:
-        return json.dumps(res.__dict__, indent=2)
-    else:
+from fastapi import APIRouter
+from fastapi.responses import StreamingResponse
+from models.chat_completion import ChatRequest
+from huggingface_hub import InferenceClient
+import json
+
+router = APIRouter()
+
+def generate_stream(response):
+    for chunk in response:
+        yield f"data: {json.dumps(chunk.__dict__, separators=(',', ':'))}\n\n"
+
+@router.post("/v1/chat/completions", tags=["Chat Completion"])
+async def chat_completion(body: ChatRequest):
+    client = InferenceClient(model=body.model)
+
+    res = client.chat_completion(
+        messages=body.messages,
+        frequency_penalty=body.frequency_penalty,
+        logit_bias=body.logit_bias,
+        logprobs=body.logprobs,
+        max_tokens=body.max_tokens,
+        n=body.n,
+        presence_penalty=body.presence_penalty,
+        response_format=body.response_format,
+        seed=body.seed,
+        stop=body.stop,
+        stream=body.stream,
+        stream_options=body.stream_options,
+        temperature=body.temperature,
+        top_logprobs=body.top_logprobs,
+        top_p=body.top_p,
+        tool_choice=body.tool_choice,
+        tool_prompt=body.tool_prompt,
+        tools=body.tools
+    )
+
+    if not body.stream:
+        return json.dumps(res.__dict__, indent=2)
+    else:
         return StreamingResponse(generate_stream(res), media_type="text/event-stream")
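
The substantive change is the route decorator on chat_completion, which now reads @router.post("/v1/chat/completions", tags=["Chat Completion"]): the endpoint sits on the OpenAI-style chat-completions path and is grouped under a "Chat Completion" tag in the generated OpenAPI docs. The remaining +40/-40 churn shows otherwise identical lines, consistent with a whitespace or line-ending change. A minimal smoke test for the updated endpoint might look like the sketch below; the base URL and model name are assumptions, not part of this commit, while the model, messages, max_tokens, and stream fields mirror what the handler reads from ChatRequest.

    # Hypothetical smoke test for the updated route (not part of the commit).
    # Assumes the FastAPI app is running locally on port 8000 and that the
    # model below is reachable through huggingface_hub's InferenceClient.
    import requests

    BASE_URL = "http://localhost:8000"  # assumed host/port

    payload = {
        "model": "HuggingFaceH4/zephyr-7b-beta",  # assumed example model
        "messages": [{"role": "user", "content": "Say hello."}],
        "max_tokens": 32,
        "stream": False,
    }

    # Non-streaming path: the handler returns the serialized completion.
    resp = requests.post(f"{BASE_URL}/v1/chat/completions", json=payload)
    print(resp.status_code, resp.text[:200])

    # Streaming path: the handler emits server-sent events ("data: {...}\n\n").
    payload["stream"] = True
    with requests.post(f"{BASE_URL}/v1/chat/completions", json=payload, stream=True) as r:
        for line in r.iter_lines():
            if line:
                print(line.decode())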