Spaces: (Paused)

dharmendra committed
Commit · d00f229
1 Parent(s): 0b5b6d7

using Llama 3.1 8B instruct
app.py CHANGED
@@ -4,72 +4,80 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 import torch
 from pydantic import BaseModel
 import traceback
-#from langchain.memory import ConversationBufferMemory
 from langchain.memory import ConversationBufferWindowMemory
 from langchain.chains import ConversationChain
-from starlette.responses import StreamingResponse # <-- NEW IMPORT
-import asyncio
-from langchain_community.llms import HuggingFacePipeline
-import json
 from langchain.prompts import PromptTemplate
+from starlette.responses import StreamingResponse
+import asyncio
+import json
+from langchain_community.llms import HuggingFacePipeline
+import uvicorn
 
 app = FastAPI()
+
 # Get the Hugging Face API token from environment variables (BEST PRACTICE)
 HUGGINGFACEHUB_API_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
 
 if HUGGINGFACEHUB_API_TOKEN is None:
     raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set.")
 
-tokenizer = AutoTokenizer.from_pretrained(
+# --- UPDATED: Use Llama 3.1 8B Instruct model ---
+model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+
+tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(
-    device_map="auto",
-    torch_dtype=torch.bfloat16,
+    model_id,
+    device_map="auto",
+    torch_dtype=torch.bfloat16, # torch.bfloat16 is generally good for Llama, can try torch.float16 if issues
+    trust_remote_code=True,
+    token=HUGGINGFACEHUB_API_TOKEN
+)
 
 if torch.backends.mps.is_available():
+    device = "mps"
 elif torch.cuda.is_available():
-else
+    device = "cuda"
+else:
+    device = "cpu"
 
 model.to(device)
+
 # k=5 means it will keep the last 5 human-AI interaction pairs (10 messages total)
 memory = ConversationBufferWindowMemory(k=5)
 
 # Initialize Langchain HuggingFacePipeline
 llm = HuggingFacePipeline(pipeline=pipeline(
-    "text-generation",
-    model=model,
+    "text-generation",
+    model=model,
     tokenizer=tokenizer,
-    max_new_tokens=512, #
+    max_new_tokens=512, # Allows for longer, detailed answers when required
+    # --- IMPORTANT FIX: Set return_full_text to True and handle slicing manually ---
+    return_full_text=True,
     temperature=0.2, # Controls randomness (0.0 for deterministic, 1.0 for very creative)
-    do_sample=True # Enable sampling for more varied outputs
+    do_sample=True, # Enable sampling for more varied outputs
+    # --- IMPORTANT FIX: REMOVED stop_sequence from pipeline initialization ---
+    # This prevents the TypeError and we handle stopping manually below.
+))
+
+# --- UPDATED PROMPT TEMPLATE ---
+# Using the recommended chat format for Llama models and explicit instructions.
+template = """<|im_start|>system
+You are a concise and direct AI assistant named Siddhi.
+You strictly avoid asking any follow-up questions.
+You do not generate any additional conversational turns (e.g., "Human: ...").
+If asked for your name, you respond with "I am Siddhi."
+If you do not know the answer to a question, you truthfully state that you do not know.
+<|im_end|>
+<|im_start|>user
 {history}
+{input}<|im_end|>
+<|im_start|>assistant
+"""
 
 PROMPT = PromptTemplate(input_variables=["history", "input"], template=template)
 
 # Initialize Langchain ConversationChain
-conversation = ConversationChain(llm=llm, memory=memory, prompt = PROMPT, verbose=True)
+conversation = ConversationChain(llm=llm, memory=memory, prompt=PROMPT, verbose=True)
 
 class QuestionRequest(BaseModel):
     question: str
@@ -80,42 +88,65 @@ class ChatResponse(BaseModel):
 @app.post("/api/generate")
 async def generate_text(request: QuestionRequest):
     async def generate_stream():
+        # Flag to indicate when we've started streaming the AI's actual response
+        started_streaming_ai_response = False
+
         try:
-            # Use LangChain's .stream() method for token-by-token generation
-            # This will yield chunks of the response as they are produced
             response_stream = conversation.stream({"input": request.question})
 
+            # Define stop sequences for manual checking
+            stop_sequences_to_check = ["Human:", "AI:", "\nHuman:", "\nAI:", "<|im_end|>"]
+            assistant_start_marker = "<|im_start|>assistant\n" # Marker from the prompt template
+
             for chunk in response_stream:
-                # Each chunk is typically a dictionary with a 'content' key
-                # We want to send just the new token/text back.
-                # Ensure the chunk is stringified and followed by a newline for client parsing.
-                # For more robust streaming, consider Server-Sent Events (SSE) format:
-                # yield f"data: {json.dumps({'token': chunk.content})}\n\n"
-                # For simplicity, we'll just yield the content directly for now.
+                full_text_chunk = ""
                 if 'response' in chunk:
+                    full_text_chunk = chunk['response']
                 else:
+                    full_text_chunk = str(chunk) # Fallback for unexpected chunk format
+
+                # Logic to extract only the AI's response from the full text chunk
+                if not started_streaming_ai_response:
+                    if assistant_start_marker in full_text_chunk:
+                        # Split the chunk at the assistant's start marker and take the part after it
+                        token_content = full_text_chunk.split(assistant_start_marker, 1)[1]
+                        started_streaming_ai_response = True
+                    else:
+                        # If the marker is not yet in the chunk, this chunk is still part of the prompt.
+                        # We don't yield anything yet.
+                        token_content = ""
+                else:
+                    # Once we've started, all subsequent chunks are AI's response
+                    token_content = full_text_chunk
+
+                # --- Manual stopping logic ---
+                # Check if the generated content contains a stop sequence.
+                # If it does, truncate the content and break the loop.
+                for stop_seq in stop_sequences_to_check:
+                    if stop_seq in token_content:
+                        token_content = token_content.split(stop_seq, 1)[0] # Truncate at the stop sequence
+                        if token_content: # Yield any content before stop sequence
+                            yield json.dumps({"content": token_content}) + "\n"
+                            await asyncio.sleep(0.01)
+                        yield json.dumps({"status": "completed"}) + "\n" # Signal completion
+                        return # Exit the generator function
+
+                # Only yield if there's actual content to send after processing
+                if token_content:
+                    yield json.dumps({"content": token_content}) + "\n"
+                    await asyncio.sleep(0.01)
+
+            # Send a final completion message if the stream finishes naturally
+            yield json.dumps({"status": "completed"}) + "\n"
 
         except Exception as e:
             print("Error during streaming generation:")
            traceback.print_exc()
+            # Yield error message in JSON format
+            yield json.dumps({"error": str(e)}) + "\n"
 
-    # Return a StreamingResponse
-    # media_type can be "text/event-stream" for SSE, or "text/plain" for simple newline-delimited text.
-    # For simplicity, we'll start with "text/plain" for easier initial client parsing.
+    # Return a StreamingResponse with application/json media type
    return StreamingResponse(generate_stream(), media_type="application/json")
 
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
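Usage note (not part of the commit): a minimal client sketch for the /api/generate streaming endpoint added above. It assumes the Space is reachable at http://localhost:7860 (the default PORT in the new __main__ block) and that the requests package is installed; the ask() helper and its base_url parameter are illustrative names, not something defined in the repository. The server writes one JSON object per line: {"content": ...} for generated text, {"status": "completed"} when the stream ends, and {"error": ...} on failure.

# Hypothetical consumer of the /api/generate stream; not part of the commit.
# Assumptions: app running locally on port 7860, `requests` installed.
import json

import requests

def ask(question: str, base_url: str = "http://localhost:7860") -> str:
    """Stream an answer from /api/generate and return the full text."""
    parts = []
    with requests.post(
        f"{base_url}/api/generate",
        json={"question": question},  # matches the QuestionRequest model
        stream=True,                  # keep the connection open and read incrementally
    ) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines(decode_unicode=True):
            if not line:
                continue  # skip blank keep-alive lines
            payload = json.loads(line)
            if "content" in payload:
                print(payload["content"], end="", flush=True)
                parts.append(payload["content"])
            elif "error" in payload:
                raise RuntimeError(payload["error"])
            elif payload.get("status") == "completed":
                break
    return "".join(parts)

if __name__ == "__main__":
    ask("What is your name?")

Because each line is a self-contained JSON object (newline-delimited JSON rather than a single JSON document, despite the application/json media type), a client has to read the body line by line as above instead of calling resp.json() once on the whole response.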
|