Update app.py
app.py CHANGED
@@ -127,78 +127,74 @@ vectorstore = Chroma.from_documents(
 )
 
 
+import cohere
+from langchain_core.prompts import PromptTemplate
+
+
 class RAGPipeline:
-    def __init__(self, vectorstore, model_name="…
+    def __init__(self, vectorstore, api_key, model_name="c4ai-aya-expanse-8b", k=3):
         self.vectorstore = vectorstore
         self.model_name = model_name
         self.k = k
+        self.api_key = api_key
+        self.client = cohere.Client(api_key)  # Initialize the Cohere client
         self.retriever = self.vectorstore.as_retriever(
-            search_type="mmr", search_kwargs={"k": …
+            search_type="mmr", search_kwargs={"k": 3}
         )
         self.prompt_template = PromptTemplate.from_template(self._get_template())
 
-        # Load model and tokenizer
-        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, token=token)
-        self.model = AutoModelForCausalLM.from_pretrained(
-            self.model_name, torch_dtype=torch.bfloat16, device_map="auto", token=token
-        )
-
     def _get_template(self):
-        return """
-        <s>[INST] <<SYS>>
+        return """<s>[INST] <<SYS>>
         أنت مساعد مفيد يقدم إجابات باللغة العربية بناءً على السياق المقدم.
         - أجب فقط باللغة العربية
         - إذا لم تجد إجابة في السياق، قل أنك لا تعرف
         - كن دقيقاً وواضحاً في إجاباتك
+        - جاوب من السياق حصريا
         <</SYS>>
 
         السياق: {context}
 
         السؤال: {question}
         الإجابة: [/INST]\
-        """
+
+        """
 
     def generate_response(self, question):
         retrieved_docs = self._retrieve_documents(question)
         prompt = self._create_prompt(retrieved_docs, question)
-        response = self.…
+        response = self._generate_response_cohere(prompt)
         return response
 
     def _retrieve_documents(self, question):
-        start = time.time()
         retrieved_docs = self.retriever.invoke(question)
-        …
-        …
-        …
-        print(…
-        …
+        # print("\n=== المستندات المسترجعة ===")
+        # for i, doc in enumerate(retrieved_docs):
+        #     print(f"المستند {i+1}: {doc.page_content}")
+        # print("==========================\n")
+
+        # Merge the retrieved texts into a single context
+        return " ".join([doc.page_content for doc in retrieved_docs])
 
     def _create_prompt(self, docs, question):
         return self.prompt_template.format(context=docs, question=question)
 
-    def …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-            top_p=0.9,
-            do_sample=True,
-            pad_token_id=self.tokenizer.eos_token_id,
+    def _generate_response_cohere(self, prompt):
+        # Call Cohere's generate API
+        response = self.client.generate(
+            model=self.model_name,
+            prompt=prompt,
+            max_tokens=2000,  # Adjust token limit based on requirements
+            temperature=0.3,  # Control creativity
+            stop_sequences=None,
         )
-        end = time.time()
-        time_lapsed = end - start
-        print(f"Time lapsed in Generation: {time_lapsed}")
-
-        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-        # Extract only the assistant's response after [/INST]
-        return response.split("[/INST]")[-1].strip()
-
 
-…
+        if response.generations:
+            return response.generations[0].text.strip()
+        else:
+            raise Exception("No response generated by Cohere API.")
 
+api_key = os.getenv("API_KEY")
+rag_pipeline = RAGPipeline(vectorstore=vectorstore, api_key=api_key)
 
 question = st.text_area("أدخل سؤالك هنا")
 if st.button("Generate Answer"):
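
For reference, most of the removed transformers-based generation method is truncated in the diff view; only its `top_p=0.9`, `do_sample=True`, and `pad_token_id` arguments, the timing prints, and the decode/`[/INST]`-split lines survive. A minimal sketch of what the deleted method plausibly looked like, with every line not visible in the diff marked as assumed:

import time
import torch

# Hypothetical reconstruction of the deleted RAGPipeline method; even its
# name is truncated in the diff, so `_generate_response` is an assumption.
def _generate_response(self, prompt):
    start = time.time()
    # Assumed: tokenize the prompt and move it to the model's device
    inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
    outputs = self.model.generate(
        **inputs,
        max_new_tokens=512,  # assumed value, not visible in the diff
        top_p=0.9,
        do_sample=True,
        pad_token_id=self.tokenizer.eos_token_id,
    )
    end = time.time()
    time_lapsed = end - start
    print(f"Time lapsed in Generation: {time_lapsed}")

    response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Extract only the assistant's response after [/INST]
    return response.split("[/INST]")[-1].strip()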
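
One detail of the added code worth noting: the new constructor accepts `k=3` and stores it as `self.k`, but the added retriever line hard-codes `search_kwargs={"k": 3}`, so the parameter is never consulted. If the argument is meant to take effect, the wiring would presumably be:

# Hypothetical fix: pass the stored k instead of a literal 3
self.retriever = self.vectorstore.as_retriever(
    search_type="mmr", search_kwargs={"k": self.k}
)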
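
The new pipeline can also be exercised outside Streamlit. A minimal usage sketch, assuming the Chroma vectorstore built earlier in app.py and an API_KEY environment variable holding a valid Cohere key (the question string is illustrative):

import os

# Same construction the diff performs at module level
rag_pipeline = RAGPipeline(vectorstore=vectorstore, api_key=os.getenv("API_KEY"))

# One round trip: MMR retrieval -> Arabic prompt -> Cohere generate
answer = rag_pipeline.generate_response("ما هي عاصمة مصر؟")  # "What is the capital of Egypt?"
print(answer)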
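
The diff ends at the `if st.button("Generate Answer"):` line, so the button handler itself is not shown. A plausible completion using standard Streamlit calls; everything inside the `if` is hypothetical:

question = st.text_area("أدخل سؤالك هنا")  # "Enter your question here"
if st.button("Generate Answer"):
    # Hypothetical handler body -- not visible in the diff
    if question.strip():
        with st.spinner("Generating answer..."):
            answer = rag_pipeline.generate_response(question)
        st.write(answer)
    else:
        st.warning("Please enter a question first.")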