import os
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
from pydantic import BaseModel
import traceback
import asyncio
#from langchain.memory import ConversationBufferMemory
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import ConversationChain
from langchain_community.llms import HuggingFacePipeline

app = FastAPI()

# Get the Hugging Face API token from environment variables (BEST PRACTICE)
HUGGINGFACEHUB_API_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
if HUGGINGFACEHUB_API_TOKEN is None:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set.")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-1.5B-Instruct",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    token=HUGGINGFACEHUB_API_TOKEN,
)
#print(f"Tokenizer attributes: {dir(tokenizer)}")

# Pick the best available device (used for manual tensor placement, e.g. in the
# non-LangChain path at the bottom of this file).
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Note: with device_map="auto", accelerate already places the model, so an explicit
# model.to(device) is unnecessary and can raise an error; it is therefore skipped.

# k=5 means it will keep the last 5 human-AI interaction pairs (10 messages total)
memory = ConversationBufferWindowMemory(k=5)
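
# Illustrative sketch (commented out), not part of the app: how the window memory behaves.
# save_context() records one human/AI turn, load_memory_variables() returns the formatted
# history, and once more than k=5 turns are stored the oldest turns are dropped. The
# example strings are placeholders.
# memory.save_context({"input": "Hello"}, {"output": "Hi! How can I help?"})
# print(memory.load_memory_variables({})["history"])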
# Initialize the LangChain HuggingFacePipeline wrapper around a transformers pipeline
llm = HuggingFacePipeline(pipeline=pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,      # Adjust as needed for desired response length
    return_full_text=False,  # Crucial for getting only the AI's response, especially when the answer is short
    temperature=0.7,         # Controls randomness (0.0 for deterministic, 1.0 for very creative)
    do_sample=True,          # Enable sampling for more varied outputs
))

# Initialize the LangChain ConversationChain
# verbose=True prints the full prompts LangChain builds, which helps with debugging
conversation = ConversationChain(llm=llm, memory=memory, verbose=True)
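
# Illustrative sketch (commented out), not part of the request flow: ConversationChain
# also supports a simple synchronous call via predict(), which records the turn in
# `memory` automatically; handy for quick local testing. The prompt is a placeholder.
# print(conversation.predict(input="What is the capital of France?"))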
class QuestionRequest(BaseModel):
    question: str


class ChatResponse(BaseModel):
    response: str
# The route decorator below is an assumption (the original file omits it); adjust the
# path to whatever your client expects.
@app.post("/generate")
async def generate_text(request: QuestionRequest):
    async def generate_stream():
        try:
            # Use LangChain's .stream() method for token-by-token generation.
            # This will yield chunks of the response as they are produced.
            response_stream = conversation.stream({"input": request.question})
            for chunk in response_stream:
                # Each chunk is a dictionary; ConversationChain returns its text under
                # the "response" key. We want to send just the new text back.
                # If the client parses newline-delimited text, append a "\n" to each piece.
                # For more robust streaming, consider Server-Sent Events (SSE) format:
                # yield f"data: {json.dumps({'token': chunk['response']})}\n\n"
                # For simplicity, we'll just yield the content directly for now.
                yield chunk.get("response", "")
                await asyncio.sleep(0.01)  # Small delay to allow the client to process chunks
        except Exception as e:
            print("Error during streaming generation:")
            traceback.print_exc()
            # You might want to yield an error message to the client here
            yield f"ERROR: {str(e)}\n"

    # Return a StreamingResponse, which sends chunks as they are yielded by generate_stream().
    # media_type can be "text/event-stream" for SSE, or "text/plain" for simple newline-delimited text.
    # For simplicity, we'll start with "text/plain" for easier initial client parsing.
    return StreamingResponse(generate_stream(), media_type="text/plain")
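
# Illustrative sketch (commented out): one way a client could consume this plain-text
# stream, assuming the server runs locally on port 8000 with the assumed "/generate"
# route above. URL and payload are placeholders.
# import requests
# with requests.post("http://localhost:8000/generate",
#                    json={"question": "Hello"}, stream=True) as resp:
#     for piece in resp.iter_content(chunk_size=None, decode_unicode=True):
#         print(piece, end="", flush=True)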
# Below: alternative (non-streaming) endpoint body when not using LangChain fully
# try:
#     # Retrieve history
#     history = memory.load_memory_variables({})['history']
#     # Create prompt with history and current question
#     prompt = f"History:\n{history}\nQuestion: {request.question}\nAnswer:"
#     inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(device)
#     with torch.no_grad():
#         outputs = model.generate(
#             inputs=inputs['input_ids'],  # Pass the 'input_ids' tensor
#             attention_mask=inputs['attention_mask'],
#             max_length=300,
#             num_beams=5,
#             no_repeat_ngram_size=2,
#             temperature=0.7,
#             top_k=50,
#             top_p=0.95,
#             do_sample=True,
#             eos_token_id=tokenizer.convert_tokens_to_ids("<|endoftext|>"),
#             pad_token_id=tokenizer.convert_tokens_to_ids("<|endoftext|>")
#         )
#     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
#     return {"response": response}
# except Exception as e:
#     print("Error during generation:")
#     traceback.print_exc()
#     raise HTTPException(status_code=500, detail=str(e))
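
# Minimal local entry point (an assumption, not in the original file): run with uvicorn
# on port 7860, the port Hugging Face Spaces expects by default; adjust host/port as needed.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)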