import os
from fastapi import FastAPI, HTTPException
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
from pydantic import BaseModel
import traceback
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import ConversationChain
from langchain.prompts import PromptTemplate
from starlette.responses import StreamingResponse
import asyncio
import json
from langchain_community.llms import HuggingFacePipeline
import uvicorn

app = FastAPI()

# Get the Hugging Face API token from environment variables (BEST PRACTICE)
HUGGINGFACEHUB_API_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
if HUGGINGFACEHUB_API_TOKEN is None:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set.")

# --- UPDATED: Use Llama 3.1 8B Instruct model ---
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=HUGGINGFACEHUB_API_TOKEN)  # gated repo: pass the token here too
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,  # torch.bfloat16 is generally good for Llama; try torch.float16 if you hit issues
    trust_remote_code=True,
    token=HUGGINGFACEHUB_API_TOKEN
)
# device_map="auto" above already places the model on the best available device(s),
# so an explicit model.to(device) is not needed (and can fail for accelerate-dispatched models).
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Detected device: {device}")
# k=5 means it will keep the last 5 human-AI interaction pairs (10 messages total)
memory = ConversationBufferWindowMemory(k=5)

# Initialize Langchain HuggingFacePipeline
llm = HuggingFacePipeline(pipeline=pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,  # Allows for longer, detailed answers when required
    # --- IMPORTANT FIX: Set return_full_text to True and handle slicing manually ---
    return_full_text=True,
    temperature=0.2,  # Controls randomness (0.0 for deterministic, 1.0 for very creative)
    do_sample=True,  # Enable sampling for more varied outputs
    # --- IMPORTANT FIX: REMOVED stop_sequence from pipeline initialization ---
    # This prevents the TypeError and we handle stopping manually below.
))
# --- UPDATED PROMPT TEMPLATE ---
# Chat-style template with explicit role markers and instructions. These <|im_start|>/<|im_end|>
# markers are plain text to the Llama 3.1 tokenizer (its native chat format uses
# <|start_header_id|>/<|eot_id|>), which is what lets the manual marker-splitting and
# stop-sequence handling below operate on the decoded output.
template = """<|im_start|>system
You are a concise and direct AI assistant named Siddhi.
You strictly avoid asking any follow-up questions.
You do not generate any additional conversational turns (e.g., "Human: ...").
If asked for your name, you respond with "I am Siddhi."
If you do not know the answer to a question, you truthfully state that you do not know.
<|im_end|>
<|im_start|>user
{history}
{input}<|im_end|>
<|im_start|>assistant
"""
PROMPT = PromptTemplate(input_variables=["history", "input"], template=template)

# Initialize Langchain ConversationChain
conversation = ConversationChain(llm=llm, memory=memory, prompt=PROMPT, verbose=True)


class QuestionRequest(BaseModel):
    question: str


class ChatResponse(BaseModel):
    response: str
# Expose the generator as a streaming POST endpoint (route path chosen here; adjust to your deployment)
@app.post("/generate")
async def generate_text(request: QuestionRequest):
    async def generate_stream():
        # Flag to indicate when we've started streaming the AI's actual response
        started_streaming_ai_response = False
        try:
            response_stream = conversation.stream({"input": request.question})

            # Define stop sequences for manual checking
            stop_sequences_to_check = ["Human:", "AI:", "\nHuman:", "\nAI:", "<|im_end|>"]
            assistant_start_marker = "<|im_start|>assistant\n"  # Marker from the prompt template

            for chunk in response_stream:
                full_text_chunk = ""
                if 'response' in chunk:
                    full_text_chunk = chunk['response']
                else:
                    full_text_chunk = str(chunk)  # Fallback for unexpected chunk format

                # Logic to extract only the AI's response from the full text chunk
                if not started_streaming_ai_response:
                    if assistant_start_marker in full_text_chunk:
                        # Split the chunk at the assistant's start marker and take the part after it
                        token_content = full_text_chunk.split(assistant_start_marker, 1)[1]
                        started_streaming_ai_response = True
                    else:
                        # If the marker is not yet in the chunk, this chunk is still part of the prompt.
                        # We don't yield anything yet.
                        token_content = ""
                else:
                    # Once we've started, all subsequent chunks are AI's response
                    token_content = full_text_chunk

                # --- Manual stopping logic ---
                # Check if the generated content contains a stop sequence.
                # If it does, truncate the content and break the loop.
                for stop_seq in stop_sequences_to_check:
                    if stop_seq in token_content:
                        token_content = token_content.split(stop_seq, 1)[0]  # Truncate at the stop sequence
                        if token_content:  # Yield any content before stop sequence
                            yield json.dumps({"content": token_content}) + "\n"
                            await asyncio.sleep(0.01)
                        yield json.dumps({"status": "completed"}) + "\n"  # Signal completion
                        return  # Exit the generator function

                # Only yield if there's actual content to send after processing
                if token_content:
                    yield json.dumps({"content": token_content}) + "\n"
                    await asyncio.sleep(0.01)

            # Send a final completion message if the stream finishes naturally
            yield json.dumps({"status": "completed"}) + "\n"
        except Exception as e:
            print("Error during streaming generation:")
            traceback.print_exc()
            # Yield error message in JSON format
            yield json.dumps({"error": str(e)}) + "\n"

    # Return a StreamingResponse with application/json media type
    return StreamingResponse(generate_stream(), media_type="application/json")


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
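For quick testing, here is a minimal client sketch (a separate script, not part of the server above) that posts a question to the /generate route registered above and prints the newline-delimited JSON chunks as they arrive. It assumes the server is reachable at http://localhost:7860 and that the requests library is installed; adjust the URL for your deployment.

# client.py - minimal sketch for consuming the streaming endpoint above
# (assumes the server is running at http://localhost:7860 with the /generate route)
import json
import requests

def stream_answer(question: str, url: str = "http://localhost:7860/generate") -> str:
    """Post a question and print/collect the streamed newline-delimited JSON chunks."""
    answer_parts = []
    with requests.post(url, json={"question": question}, stream=True) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines(decode_unicode=True):
            if not line:
                continue  # Skip empty keep-alive lines
            chunk = json.loads(line)
            if "content" in chunk:
                print(chunk["content"], end="", flush=True)
                answer_parts.append(chunk["content"])
            elif "error" in chunk:
                raise RuntimeError(chunk["error"])
            elif chunk.get("status") == "completed":
                break
    print()
    return "".join(answer_parts)

if __name__ == "__main__":
    stream_answer("What is your name?")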