from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from langchain_ollama import OllamaLLM
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.vectorstores import FAISS
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader

import traceback

# from langchain_core.output_parsers import StrOutputParser
# from langchain_core.runnables import RunnablePassthrough
import os

# Point the Hugging Face cache at a writable directory (needed on Spaces,
# where the default cache location is not writable).
os.environ["HF_HOME"] = "/tmp/huggingface"
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
# Load and split documents
loader = TextLoader("knowledge_base.txt", encoding="utf-8")
documents = loader.load()

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=50,
    # Arabic question mark and comma restored from mis-encoded characters
    separators=["\n\n", "\n", ".", "!", "?", "؟", "،", ";", ","],
)
splits = text_splitter.split_documents(documents)
# Generate embeddings and store in FAISS. The index is built at import time,
# so startup blocks until the whole knowledge base has been embedded.
# Note: the e5 model card recommends "query: " / "passage: " prefixes on
# inputs; this setup omits them, which can reduce retrieval quality.
embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large")
vectorstore = FAISS.from_documents(splits, embeddings)
# With FAISS's default L2 index, score_threshold acts as a distance cutoff
# (smaller = closer), not a similarity score.
retriever = vectorstore.as_retriever(search_kwargs={"k": 5, "score_threshold": 0.4})
# Define improved prompt template
template = """
You are an AI assistant. You must ALWAYS respond in the EXACT SAME LANGUAGE as the user's question or message. This is crucial:
- If the user writes in English, you MUST respond in English
- If the user writes in Arabic, you MUST respond in Arabic (Modern Standard Arabic)
- Mixed-language messages should get responses in the predominant language of the message

Conversation history:
{history}

Relevant information from knowledge base:
{context}

User's message: {question}

Key requirements:
1. MATCH THE LANGUAGE OF THE USER'S MESSAGE EXACTLY
2. Use the provided context and history to answer the question
3. Maintain your identity as an AI assistant
4. Never pretend to be the user or adopt their name
5. For greetings and casual conversation, respond naturally without using the knowledge base
6. Only use the knowledge base content when directly relevant to a specific question

Response:
"""
prompt = ChatPromptTemplate.from_template(template)
# Load model with adjusted parameters
model = OllamaLLM(
    model="mistral",
    temperature=0.1,
    num_ctx=8192,
    top_p=0.8,
)
def format_conversation_history(history):
    # Join history entries one per line (keeps the trailing newline)
    return "".join(f"{entry}\n" for entry in history)
# Manual RAG pipeline: retrieve, format the prompt, then call the model
def generate_response(question, history, retriever):
    # Get relevant documents
    context = retriever.invoke(question)
    context_str = "\n".join(doc.page_content for doc in context)

    # Format the conversation history
    history_str = format_conversation_history(history)

    # Prepare the input for the prompt
    chain_input = {"context": context_str, "history": history_str, "question": question}

    # Render the prompt, then generate a response with the model
    prompt_text = prompt.format(**chain_input)
    return model.invoke(prompt_text)
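

# Alternative: the same pipeline expressed as an LCEL chain. A minimal sketch,
# assuming this is what the commented-out StrOutputParser / RunnablePassthrough
# imports above were intended for; it is not wired into the endpoints.
def generate_response_lcel(question, history, retriever):
    from langchain_core.output_parsers import StrOutputParser

    docs = retriever.invoke(question)
    chain = prompt | model | StrOutputParser()
    return chain.invoke(
        {
            "context": "\n".join(doc.page_content for doc in docs),
            "history": format_conversation_history(history),
            "question": question,
        }
    )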
def chatbot_conversation():
    print("Hello! I'm an AI assistant. Type 'exit' to quit.")
    conversation_history = []
    while True:
        user_input = input("You: ").strip()
        if user_input.lower() == "exit":
            break
        try:
            # Generate response
            result = generate_response(user_input, conversation_history, retriever)
            print(f"Assistant: {result}")

            # Store the exchange in history
            conversation_history.append(f"User: {user_input}")
            conversation_history.append(f"Assistant: {result}")
        except Exception as e:
            print(f"An error occurred: {str(e)}")
            print("Assistant: I apologize, but I encountered an error. Please try again.")
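

# chatbot_conversation() is a local command-line test loop; it is never called
# when this module is served as an API. Invoke it directly to test the RAG
# pipeline without going through HTTP.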
# In-memory conversation history keyed by user_id (lost on restart; a real
# deployment would persist this)
chat_histories = {}


class ChatRequest(BaseModel):
    user_id: str  # Unique ID for tracking history per user
    message: str
# Route decorator restored: the original definition had none, so the endpoint
# was never registered with FastAPI ("/chat" is an assumed path).
@app.post("/chat")
def chat(request: ChatRequest):
    try:
        # Retrieve the user's conversation history or create a new one
        if request.user_id not in chat_histories:
            chat_histories[request.user_id] = []

        # Get conversation history
        history = chat_histories[request.user_id]

        # Generate response
        response = generate_response(request.message, history, retriever)

        # Update history
        history.append(f"User: {request.message}")
        history.append(f"Assistant: {response}")

        return {"response": response}
    except Exception as e:
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e))
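

# --- Entry point (a minimal sketch; not in the original source) ---
# Assumes uvicorn is installed; port 7860 is the Hugging Face Spaces convention.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)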