# PY_LLM_NEW / app.py
import os
from fastapi import FastAPI, HTTPException
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
from pydantic import BaseModel
import traceback
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import ConversationChain
from langchain.prompts import PromptTemplate
from starlette.responses import StreamingResponse
import asyncio
import json
from langchain_community.llms import HuggingFacePipeline
import uvicorn
app = FastAPI()
# Get the Hugging Face API token from environment variables (BEST PRACTICE)
HUGGINGFACEHUB_API_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
if HUGGINGFACEHUB_API_TOKEN is None:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set.")
# --- UPDATED: Use Llama 3.1 8B Instruct model ---
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=HUGGINGFACEHUB_API_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,  # bfloat16 is generally good for Llama; try torch.float16 if issues arise
    trust_remote_code=True,
    token=HUGGINGFACEHUB_API_TOKEN,
)
# device_map="auto" already dispatches the model across available devices via
# accelerate; calling .to() on such a model raises an error, so only move it
# manually when no device map was applied.
if getattr(model, "hf_device_map", None) is None:
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    model.to(device)
# k=5 means it will keep the last 5 human-AI interaction pairs (10 messages total)
memory = ConversationBufferWindowMemory(k=5)
# Initialize Langchain HuggingFacePipeline
llm = HuggingFacePipeline(pipeline=pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,  # allows for longer, detailed answers when required
    # IMPORTANT: return_full_text=True; the assistant's part is sliced out manually below
    return_full_text=True,
    temperature=0.2,  # controls randomness (0.0 deterministic, 1.0 very creative)
    do_sample=True,  # enable sampling for more varied outputs
    # NOTE: stop_sequence is intentionally NOT passed to the pipeline (it caused a
    # TypeError); stopping is handled manually in the streaming loop below.
))
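# Optional smoke test (a minimal sketch, commented out so nothing runs at import
# time); it assumes the model and tokenizer above loaded successfully:
#
#   print(llm.invoke(
#       "<|im_start|>user\nSay hello in one sentence.<|im_end|>\n<|im_start|>assistant\n"
#   ))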
# --- UPDATED PROMPT TEMPLATE ---
# NOTE: this template uses ChatML-style markers (<|im_start|>/<|im_end|>) rather
# than Llama 3.1's native header tokens. The streaming logic below depends on
# these exact markers, so keep them in sync if the format changes.
template = """<|im_start|>system
You are a concise and direct AI assistant named Siddhi.
You strictly avoid asking any follow-up questions.
You do not generate any additional conversational turns (e.g., "Human: ...").
If asked for your name, you respond with "I am Siddhi."
If you do not know the answer to a question, you truthfully state that you do not know.
<|im_end|>
<|im_start|>user
{history}
{input}<|im_end|>
<|im_start|>assistant
"""
PROMPT = PromptTemplate(input_variables=["history", "input"], template=template)
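# Illustrative only: with one prior exchange in the window memory, the rendered
# prompt sent to the pipeline looks roughly like
#
#   <|im_start|>system
#   You are a concise and direct AI assistant named Siddhi.
#   ...
#   <|im_end|>
#   <|im_start|>user
#   Human: hi
#   AI: Hello! How can I help?
#   what is your name?<|im_end|>
#   <|im_start|>assistant
#
# Because return_full_text=True, this whole prompt is echoed back by the pipeline;
# generate_stream below slices off everything up to the assistant marker.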
# Initialize Langchain ConversationChain
conversation = ConversationChain(llm=llm, memory=memory, prompt=PROMPT, verbose=True)
class QuestionRequest(BaseModel):
    question: str
class ChatResponse(BaseModel):
    response: str
@app.post("/api/generate")
async def generate_text(request: QuestionRequest):
    async def generate_stream():
        # Flag to indicate when we've started streaming the AI's actual response
        started_streaming_ai_response = False
        try:
            response_stream = conversation.stream({"input": request.question})
            # Stop sequences for manual checking
            stop_sequences_to_check = ["Human:", "AI:", "\nHuman:", "\nAI:", "<|im_end|>"]
            assistant_start_marker = "<|im_start|>assistant\n"  # marker from the prompt template
            for chunk in response_stream:
                full_text_chunk = ""
                if 'response' in chunk:
                    full_text_chunk = chunk['response']
                else:
                    full_text_chunk = str(chunk)  # fallback for unexpected chunk format
                # Extract only the AI's response from the full text chunk
                if not started_streaming_ai_response:
                    if assistant_start_marker in full_text_chunk:
                        # Split the chunk at the assistant's start marker and keep the part after it
                        token_content = full_text_chunk.split(assistant_start_marker, 1)[1]
                        started_streaming_ai_response = True
                    else:
                        # The marker has not appeared yet, so this chunk is still part of
                        # the prompt; don't yield anything.
                        token_content = ""
                else:
                    # Once we've started, all subsequent chunks are the AI's response
                    token_content = full_text_chunk
                # --- Manual stopping logic ---
                # If the generated content contains a stop sequence, truncate at it
                # and end the stream.
                for stop_seq in stop_sequences_to_check:
                    if stop_seq in token_content:
                        token_content = token_content.split(stop_seq, 1)[0]  # truncate at the stop sequence
                        if token_content:  # yield any content before the stop sequence
                            yield json.dumps({"content": token_content}) + "\n"
                            await asyncio.sleep(0.01)
                        yield json.dumps({"status": "completed"}) + "\n"  # signal completion
                        return  # exit the generator
                # Only yield if there's actual content to send after processing
                if token_content:
                    yield json.dumps({"content": token_content}) + "\n"
                    await asyncio.sleep(0.01)
            # Send a final completion message if the stream finishes naturally
            yield json.dumps({"status": "completed"}) + "\n"
        except Exception as e:
            print("Error during streaming generation:")
            traceback.print_exc()
            # Yield the error message in JSON format
            yield json.dumps({"error": str(e)}) + "\n"
    # Each line of the response is a standalone JSON object (newline-delimited JSON)
    return StreamingResponse(generate_stream(), media_type="application/json")
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
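# Example client (hypothetical usage, not part of the app): a minimal sketch for
# consuming the newline-delimited JSON stream produced by /api/generate:
#
#   import json
#   import requests
#
#   with requests.post(
#       "http://localhost:7860/api/generate",
#       json={"question": "What is your name?"},
#       stream=True,
#   ) as resp:
#       for line in resp.iter_lines():
#           if not line:
#               continue
#           payload = json.loads(line)
#           if "content" in payload:
#               print(payload["content"], end="", flush=True)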