Mojo3 commited on
Commit
cc4a792
·
verified ·
1 Parent(s): 8660efe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -38
app.py CHANGED
@@ -127,78 +127,74 @@ vectorstore = Chroma.from_documents(
127
  )
128
 
129
 
 
 
 
 
130
  class RAGPipeline:
131
- def __init__(self, vectorstore, model_name="CohereForAI/aya-expanse-8b", k=6):
132
  self.vectorstore = vectorstore
133
  self.model_name = model_name
134
  self.k = k
 
 
135
  self.retriever = self.vectorstore.as_retriever(
136
- search_type="mmr", search_kwargs={"k": self.k}
137
  )
138
  self.prompt_template = PromptTemplate.from_template(self._get_template())
139
 
140
- # Load model and tokenizer
141
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, token=token)
142
- self.model = AutoModelForCausalLM.from_pretrained(
143
- self.model_name, torch_dtype=torch.bfloat16, device_map="auto", token=token
144
- )
145
-
146
  def _get_template(self):
147
- return """\
148
- <s>[INST] <<SYS>>
149
  أنت مساعد مفيد يقدم إجابات باللغة العربية بناءً على السياق المقدم.
150
  - أجب فقط باللغة العربية
151
  - إذا لم تجد إجابة في السياق، قل أنك لا تعرف
152
  - كن دقيقاً وواضحاً في إجاباتك
 
153
  <</SYS>>
154
 
155
  السياق: {context}
156
 
157
  السؤال: {question}
158
  الإجابة: [/INST]\
159
- """
 
160
 
161
  def generate_response(self, question):
162
  retrieved_docs = self._retrieve_documents(question)
163
  prompt = self._create_prompt(retrieved_docs, question)
164
- response = self._generate_response(prompt)
165
  return response
166
 
167
  def _retrieve_documents(self, question):
168
- start = time.time()
169
  retrieved_docs = self.retriever.invoke(question)
170
- result = {f"doc_{i}": doc.page_content for i, doc in enumerate(retrieved_docs)}
171
- end = time.time()
172
- time_lapsed = end - start
173
- print(f"Time lapsed in Retreival: {time_lapsed}")
174
- return result
 
 
175
 
176
  def _create_prompt(self, docs, question):
177
  return self.prompt_template.format(context=docs, question=question)
178
 
179
- def _generate_response(self, prompt):
180
- inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
181
-
182
- start = time.time()
183
- outputs = self.model.generate(
184
- inputs.input_ids,
185
- max_new_tokens=1024,
186
- temperature=0.7,
187
- top_p=0.9,
188
- do_sample=True,
189
- pad_token_id=self.tokenizer.eos_token_id,
190
  )
191
- end = time.time()
192
- time_lapsed = end - start
193
- print(f"Time lapsed in Generation: {time_lapsed}")
194
-
195
- response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
196
- # Extract only the assistant's response after [/INST]
197
- return response.split("[/INST]")[-1].strip()
198
-
199
 
200
- rag_pipeline = RAGPipeline(vectorstore)
 
 
 
201
 
 
 
202
 
203
  question = st.text_area("أدخل سؤالك هنا")
204
  if st.button("Generate Answer"):
 
127
  )
128
 
129
 
130
+ import cohere
131
+ from langchain_core.prompts import PromptTemplate
132
+
133
+
134
  class RAGPipeline:
135
+ def __init__(self, vectorstore, api_key, model_name="c4ai-aya-expanse-8b", k=3):
136
  self.vectorstore = vectorstore
137
  self.model_name = model_name
138
  self.k = k
139
+ self.api_key = api_key
140
+ self.client = cohere.Client(api_key) # Initialize the Cohere client
141
  self.retriever = self.vectorstore.as_retriever(
142
+ search_type="mmr", search_kwargs={"k": 3}
143
  )
144
  self.prompt_template = PromptTemplate.from_template(self._get_template())
145
 
 
 
 
 
 
 
146
  def _get_template(self):
147
+ return """<s>[INST] <<SYS>>
 
148
  أنت مساعد مفيد يقدم إجابات باللغة العربية بناءً على السياق المقدم.
149
  - أجب فقط باللغة العربية
150
  - إذا لم تجد إجابة في السياق، قل أنك لا تعرف
151
  - كن دقيقاً وواضحاً في إجاباتك
152
+ -جاوب من السياق حصريا
153
  <</SYS>>
154
 
155
  السياق: {context}
156
 
157
  السؤال: {question}
158
  الإجابة: [/INST]\
159
+
160
+ """
161
 
162
  def generate_response(self, question):
163
  retrieved_docs = self._retrieve_documents(question)
164
  prompt = self._create_prompt(retrieved_docs, question)
165
+ response = self._generate_response_cohere(prompt)
166
  return response
167
 
168
  def _retrieve_documents(self, question):
 
169
  retrieved_docs = self.retriever.invoke(question)
170
+ # print("\n=== المستندات المسترجعة ===")
171
+ # for i, doc in enumerate(retrieved_docs):
172
+ # print(f"المستند {i+1}: {doc.page_content}")
173
+ # print("==========================\n")
174
+
175
+ # دمج النصوص المسترجعة في سياق واحد
176
+ return " ".join([doc.page_content for doc in retrieved_docs])
177
 
178
  def _create_prompt(self, docs, question):
179
  return self.prompt_template.format(context=docs, question=question)
180
 
181
+ def _generate_response_cohere(self, prompt):
182
+ # Call Cohere's generate API
183
+ response = self.client.generate(
184
+ model=self.model_name,
185
+ prompt=prompt,
186
+ max_tokens=2000, # Adjust token limit based on requirements
187
+ temperature=0.3, # Control creativity
188
+ stop_sequences=None,
 
 
 
189
  )
 
 
 
 
 
 
 
 
190
 
191
+ if response.generations:
192
+ return response.generations[0].text.strip()
193
+ else:
194
+ raise Exception("No response generated by Cohere API.")
195
 
196
+ api_key = os.getenv("API_KEY")
197
+ rag_pipeline = RAGPipeline(vectorstore=vectorstore, api_key=api_key)
198
 
199
  question = st.text_area("أدخل سؤالك هنا")
200
  if st.button("Generate Answer"):