abdibrahem commited on
Commit
56623dc
·
verified ·
1 Parent(s): 7b12046

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +151 -0
main.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from langchain_ollama import OllamaLLM
4
+ from langchain_core.prompts import ChatPromptTemplate
5
+ from langchain_community.vectorstores import FAISS
6
+ from langchain_huggingface.embeddings import HuggingFaceEmbeddings
7
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
8
+ from langchain_community.document_loaders import TextLoader
9
+
10
+ import traceback
11
+
12
+ # from langchain_core.output_parsers import StrOutputParser
13
+ # from langchain_core.runnables import RunnablePassthrough
14
+
15
+ import os
16
+
17
# Point the Hugging Face cache at a writable location (e.g. container /tmp).
# NOTE(review): huggingface_hub typically reads HF_HOME at import time, and the
# langchain_huggingface import above runs before this assignment — confirm the
# cache actually lands in /tmp/huggingface.
os.environ["HF_HOME"] = "/tmp/huggingface"

app = FastAPI()

# Silence the symlink warning emitted by huggingface_hub on some filesystems.
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

# Load and split documents from the local knowledge base file.
loader = TextLoader("knowledge_base.txt", encoding="utf-8")
documents = loader.load()
# Separators include Arabic punctuation (،, ؟) so Arabic text splits on
# sentence boundaries as well as Latin punctuation.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=50,
    separators=["\n\n", "\n", ".", "!", "?", "،", "؟", "!", ";", ","],
)
splits = text_splitter.split_documents(documents)

# Generate embeddings and store in FAISS (multilingual model to cover both
# English and Arabic queries).
embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large")
vectorstore = FAISS.from_documents(splits, embeddings)
# NOTE(review): score_threshold inside search_kwargs is only honored by
# retriever/search paths that support it — verify it is applied with the
# default similarity search type.
retriever = vectorstore.as_retriever(search_kwargs={"k": 5, "score_threshold": 0.4})
37
+
38
# Define improved prompt template.
# The template enforces language-matching (English/Arabic) and injects three
# variables: {history}, {context}, and {question}.
template = """
You are an AI assistant. You must ALWAYS respond in the EXACT SAME LANGUAGE as the user's question or message. This is crucial:
- If the user writes in English, you MUST respond in English
- If the user writes in Arabic, you MUST respond in Arabic (Modern Standard Arabic)
- Mixed language messages should get responses in the predominant language of the message
Conversation history:
{history}
Relevant information from knowledge base:
{context}
User's message: {question}
Key requirements:
1. MATCH THE LANGUAGE OF THE USER'S MESSAGE EXACTLY
2. Use the provided context and history to answer the question
3. Maintain your identity as an AI assistant
4. Never pretend to be the user or adopt their name
5. For greetings and casual conversation, respond naturally without using the knowledge base
6. Only use the knowledge base content when directly relevant to a specific question
Response:
"""

prompt = ChatPromptTemplate.from_template(template)
60
+
61
# Load model with adjusted parameters: a local Ollama "mistral" model with a
# low temperature for more deterministic answers, a large context window
# (8192 tokens) to fit history plus retrieved chunks, and nucleus sampling
# capped at top_p=0.8.
model = OllamaLLM(
    model="mistral",
    temperature=0.1,
    num_ctx=8192,
    top_p=0.8,
)
68
+
69
+
70
def format_conversation_history(history):
    """Render each history entry on its own newline-terminated line."""
    return "".join(f"{entry}\n" for entry in history)
75
+
76
+
77
# Create RAG chain with properly handled input types
def generate_response(question, history, retriever):
    """Answer *question* via RAG: retrieve context, fill the prompt, query the LLM."""
    # Pull the top matching chunks and flatten them into one text block.
    docs = retriever.invoke(question)
    knowledge = "\n".join(d.page_content for d in docs)

    # Render the prompt with the formatted history and retrieved context,
    # then send the resulting text to the model.
    rendered = prompt.format(
        context=knowledge,
        history=format_conversation_history(history),
        question=question,
    )
    return model.invoke(rendered)
94
+
95
+
96
def chatbot_conversation():
    """Run an interactive console chat loop against the RAG pipeline.

    Type 'exit' to quit. History is kept in-memory for the session only.
    """
    print("Hello! I'm an AI assistant. Type 'exit' to quit.")

    conversation_history = []

    while True:
        user_input = input("You: ").strip()
        if user_input.lower() == 'exit':
            break

        try:
            # Generate response
            result = generate_response(user_input, conversation_history, retriever)
        except Exception as e:
            # Keep the loop alive on any failure and tell the user.
            print(f"An error occurred: {str(e)}")
            print(
                "Assistant: I apologize, but I encountered an error. Please try again."
            )
        else:
            print(f"Assistant: {result}")
            # Record both sides of the exchange so later turns see the context.
            conversation_history.append(f"User: {user_input}")
            conversation_history.append(f"Assistant: {result}")
121
+
122
+
123
# In-memory per-user conversation history for the HTTP API, keyed by user_id.
# NOTE(review): grows without bound and is lost on restart — confirm acceptable.
chat_histories = {}
124
+
125
+
126
class ChatRequest(BaseModel):
    """Request body for POST /chat."""

    user_id: str  # Unique ID for tracking history per user
    message: str  # The user's chat message
129
+
130
+
131
@app.post("/chat")
def chat(request: ChatRequest):
    """Handle one chat turn: look up the user's history, generate, persist."""
    try:
        # Fetch this user's running history, lazily creating it on first use.
        history = chat_histories.setdefault(request.user_id, [])

        # Generate response
        response = generate_response(request.message, history, retriever)

        # Append both sides of the exchange so future turns see the context.
        history.append(f"User: {request.message}")
        history.append(f"Assistant: {response}")

        return {"response": response}
    except Exception as e:
        # Log the full traceback server-side; return the message to the client.
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e))