# PY_LLM_NEW / app.py
# Streaming responses for the LLM API
import os
import asyncio
import traceback

import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# from langchain.memory import ConversationBufferMemory
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import ConversationChain
from langchain_community.llms import HuggingFacePipeline

app = FastAPI()
# Get the Hugging Face API token from environment variables (best practice)
HUGGINGFACEHUB_API_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
if HUGGINGFACEHUB_API_TOKEN is None:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set.")

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-1.5B-Instruct",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    token=HUGGINGFACEHUB_API_TOKEN,
)
#print(f"Tokenizer attributes: {dir(tokenizer)}")
if torch.backends.mps.is_available():
device = "mps"
elif torch.cuda.is_available():
device= "cuda"
else :
device = "cpu"
model.to(device)
# k=5 means it will keep the last 5 human-AI interaction pairs (10 messages total)
memory = ConversationBufferWindowMemory(k=5)
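# Illustration only (kept commented out so it does not run at import time): a
# minimal sketch of how the window memory trims history, assuming LangChain's
# standard save_context/load_memory_variables API. The example turns are made up.
#
# for i in range(7):
#     memory.save_context({"input": f"question {i}"}, {"output": f"answer {i}"})
# print(memory.load_memory_variables({})["history"])
# # With k=5, only the last 5 question/answer pairs (2-6) remain in "history".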
# Initialize Langchain HuggingFacePipeline
llm = HuggingFacePipeline(pipeline=pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,      # Adjust as needed for the desired response length
    return_full_text=False,  # Crucial for getting only the AI's response, esp. when the answer is short
    temperature=0.7,         # Controls randomness (0.0 for deterministic, 1.0 for very creative)
    do_sample=True,          # Enable sampling for more varied outputs
))
# Initialize Langchain ConversationChain
# verbose=True prints LangChain's prompts and chain steps for debugging
conversation = ConversationChain(llm=llm, memory=memory, verbose=True)
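# Optional sanity check, kept commented out so it neither runs at import time
# nor writes to the shared conversation memory. llm.invoke() and
# conversation.predict() are standard LangChain calls; the prompt text here is
# just an illustration.
#
# print(llm.invoke("Say hello in one short sentence."))
# print(conversation.predict(input="What is FastAPI?"))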
class QuestionRequest(BaseModel):
    question: str


class ChatResponse(BaseModel):
    response: str

@app.post("/api/generate")
async def generate_text(request: QuestionRequest):
    async def generate_stream():
        try:
            # Use LangChain's .stream() method for incremental generation.
            # Note: ConversationChain has no native token streaming, so each
            # chunk may be the full output dict rather than a single token.
            response_stream = conversation.stream({"input": request.question})
            for chunk in response_stream:
                # A chunk is typically a dict keyed by the chain's output key
                # ("response"); fall back to str() for message-like chunks.
                if isinstance(chunk, dict):
                    text = chunk.get("response", "")
                else:
                    text = getattr(chunk, "content", str(chunk))
                # For more robust streaming, consider Server-Sent Events (SSE):
                # yield f"data: {json.dumps({'token': text})}\n\n"
                # For simplicity, yield the text directly for now.
                yield text
                await asyncio.sleep(0.01)  # Small delay so the client can process chunks
        except Exception as e:
            print("Error during streaming generation:")
            traceback.print_exc()
            # Surface the error to the client as part of the stream
            yield f"ERROR: {str(e)}\n"

    # StreamingResponse sends chunks as generate_stream() yields them.
    # media_type can be "text/event-stream" for SSE; "text/plain" keeps
    # initial client-side parsing simple.
    return StreamingResponse(generate_stream(), media_type="text/plain")
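# Example client for the streaming endpoint, kept commented out (run it from a
# separate script/terminal). The localhost URL assumes a default
# `uvicorn app:app` launch on port 8000; adjust for your deployment.
#
# import requests
# with requests.post(
#     "http://localhost:8000/api/generate",
#     json={"question": "What is LangChain?"},
#     stream=True,
# ) as resp:
#     resp.raise_for_status()
#     for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#         print(chunk, end="", flush=True)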
# Alternative (non-streaming) path that calls the model directly instead of
# going through LangChain:
# try:
#     # Retrieve history
#     history = memory.load_memory_variables({})['history']
#     # Create prompt with history and current question
#     prompt = f"History:\n{history}\nQuestion: {request.question}\nAnswer:"
#     inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(device)
#     with torch.no_grad():
#         outputs = model.generate(
#             inputs=inputs['input_ids'],  # Pass the 'input_ids' tensor
#             attention_mask=inputs['attention_mask'],
#             max_length=300,
#             num_beams=5,
#             no_repeat_ngram_size=2,
#             temperature=0.7,
#             top_k=50,
#             top_p=0.95,
#             do_sample=True,
#             eos_token_id=tokenizer.convert_tokens_to_ids("<|endoftext|>"),
#             pad_token_id=tokenizer.convert_tokens_to_ids("<|endoftext|>")
#         )
#     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
#     return {"response": response}
# except Exception as e:
#     print("Error during generation:")
#     traceback.print_exc()
#     raise HTTPException(status_code=500, detail=str(e))