Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
|
2 |
from fastapi.responses import HTMLResponse
|
3 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
4 |
|
@@ -14,7 +14,6 @@ def get_chatbot_response(user_input: str, max_length=100):
|
|
14 |
if not user_input:
|
15 |
return "Please say something!"
|
16 |
|
17 |
-
# Encode the input and generate a response
|
18 |
input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
|
19 |
chat_history_ids = model.generate(
|
20 |
input_ids,
|
@@ -26,7 +25,6 @@ def get_chatbot_response(user_input: str, max_length=100):
|
|
26 |
top_p=0.95,
|
27 |
temperature=0.8
|
28 |
)
|
29 |
-
# Decode the response, skipping the input part
|
30 |
response = tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
|
31 |
return response.strip()
|
32 |
|
@@ -180,4 +178,4 @@ async def chat_endpoint(data: dict):
|
|
180 |
|
181 |
if __name__ == "__main__":
|
182 |
import uvicorn
|
183 |
-
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
1 |
+
from fastapi import FastAPI, HTTPException
|
2 |
from fastapi.responses import HTMLResponse
|
3 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
4 |
|
|
|
14 |
if not user_input:
|
15 |
return "Please say something!"
|
16 |
|
|
|
17 |
input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
|
18 |
chat_history_ids = model.generate(
|
19 |
input_ids,
|
|
|
25 |
top_p=0.95,
|
26 |
temperature=0.8
|
27 |
)
|
|
|
28 |
response = tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
|
29 |
return response.strip()
|
30 |
|
|
|
# Script entry point: run the ASGI server only when app.py is executed
# directly (not when imported). Binding to 0.0.0.0 on port 7860 follows
# the Hugging Face Spaces convention for exposed web apps.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)