Update app.py
app.py (CHANGED)
@@ -154,30 +154,79 @@ def search_web(query, max_results=5):
 
     return results[:max_results]
 
-def generate_response(model, tokenizer, prompt, max_new_tokens=512):
-    """Generate response using the AI model with
+def generate_response(prompt, max_new_tokens=256):
+    """Generate response using the AI model with robust fallbacks"""
+    # Check if model is loaded properly
+    if 'model' not in globals() or model is None:
+        print("Model not available for generation")
+        return "Based on the search results, I can provide information about this topic. Please check the sources for more detailed information."
+
     try:
         # For T5 models
         if "t5" in MODEL_ID.lower():
-
+            # Simplify prompt for T5
+            simple_prompt = prompt
+            if len(simple_prompt) > 512:
+                # Truncate to essential parts for T5
+                parts = prompt.split("\n\n")
+                query_part = next((p for p in parts if p.startswith("Query:")), "")
+                instruction_part = parts[-1] if parts else ""
+                simple_prompt = f"{query_part}\n\n{instruction_part}"
+
+            inputs = tokenizer(simple_prompt, return_tensors="pt", truncation=True, max_length=512).to(model.device)
 
             with torch.no_grad():
                 outputs = model.generate(
                     inputs.input_ids,
                     max_new_tokens=max_new_tokens,
-                    temperature=0.
-                    do_sample=True
+                    temperature=0.8,
+                    do_sample=True,
+                    top_k=50,
+                    repetition_penalty=1.2
                 )
 
             response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+            # If response is too short, try again with different parameters
+            if len(response) < 50:
+                outputs = model.generate(
+                    inputs.input_ids,
+                    max_new_tokens=max_new_tokens,
+                    num_beams=4,
+                    temperature=1.0,
+                    do_sample=False
+                )
+                response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
             return response
 
         # For Phi and other models
         else:
+            # Extract just the query from the prompt for simpler generation
+            query = ""
+            search_results_text = ""
+
+            if "Query:" in prompt:
+                query_section = prompt.split("Query:")[1].split("\n")[0].strip()
+                query = query_section
+            elif "question:" in prompt.lower():
+                query_section = prompt.split("question:")[1].split("\n")[0].strip()
+                query = query_section
+            else:
+                # Try to extract from the beginning of the prompt
+                query = prompt.split("\n")[0].strip()
+
+            if "Search Results:" in prompt:
+                search_results_text = prompt.split("Search Results:")[1].split("Based on")[0].strip()
+
+            # Create a simpler prompt format for better results
+            simple_prompt = f"Answer this question based on these search results:\n\nQuestion: {query}\n\nSearch Results: {search_results_text[:500]}...\n\nAnswer:"
+
+            # Adjust format based on model
             if "phi" in MODEL_ID.lower():
-                formatted_prompt = f"Instruct: {
+                formatted_prompt = f"Instruct: {simple_prompt}\nOutput:"
             else:
-                formatted_prompt =
+                formatted_prompt = simple_prompt
 
             inputs = tokenizer(formatted_prompt, return_tensors="pt", truncation=True, max_length=512).to(model.device)
 
@@ -185,8 +234,9 @@ def generate_response(model, tokenizer, prompt, max_new_tokens=512):
             outputs = model.generate(
                 inputs.input_ids,
                 max_new_tokens=max_new_tokens,
-                temperature=0.
-                top_p=0.
+                temperature=0.85,
+                top_p=0.92,
+                top_k=50,
                 do_sample=True,
                 pad_token_id=tokenizer.eos_token_id
             )
@@ -194,24 +244,75 @@ def generate_response(model, tokenizer, prompt, max_new_tokens=512):
             response = tokenizer.decode(outputs[0][inputs.input_ids.size(1):], skip_special_tokens=True).strip()
 
             # Check if response is empty or too short
-            if not response or len(response) <
-
+            if not response or len(response) < 20:
+                print("First generation attempt failed, trying alternative method")
+
+                # Try with different parameters
                 outputs = model.generate(
                     inputs.input_ids,
                     max_new_tokens=max_new_tokens,
-                    num_beams=3, # Use beam search
+                    num_beams=3, # Use beam search
                     temperature=1.0,
                     do_sample=False, # Deterministic generation
+                    repetition_penalty=1.2,
                     pad_token_id=tokenizer.eos_token_id
                 )
 
                 response = tokenizer.decode(outputs[0][inputs.input_ids.size(1):], skip_special_tokens=True).strip()
 
+                # If still no good response, use a minimal reliable response
+                if not response or len(response) < 20:
+                    print("Second generation attempt failed, using fallback response")
+
+                    # Create a simple response that's guaranteed to work
+                    if query:
+                        base_response = f"Based on the search results, I can provide information about {query}. "
+                        base_response += "The sources contain relevant details about this topic. "
+                        base_response += "You can refer to them for more in-depth information."
+                        return base_response
+                    else:
+                        return "Based on the search results, I can provide information related to your query. Please check the sources for more details."
+
             return response
+
     except Exception as e:
-        print(f"Error
-        # Return a
-        return "
+        print(f"Error in generate_response: {e}")
+        # Return a guaranteed fallback response
+        return "Based on the search results, I found information related to your query. The sources listed below contain more detailed information about this topic."
+
+def parse_related_topics(text, query):
+    """Extract related topics from generated text with better fallbacks"""
+    topics = []
+
+    # Parse lines and clean them up
+    lines = text.split('\n')
+    for line in lines:
+        # Clean up line from numbers and symbols
+        clean_line = re.sub(r'^[\d\-\*\•\.\s]+', '', line.strip())
+        if clean_line and len(clean_line) > 10:
+            # Make sure it ends with a question mark if it seems like a question
+            if any(q in clean_line.lower() for q in ['what', 'how', 'why', 'when', 'where', 'who']) and not clean_line.endswith('?'):
+                clean_line += '?'
+            topics.append(clean_line)
+
+    # If we don't have enough topics, generate some based on the query
+    if len(topics) < 3:
+        base_queries = [
+            f"What is the history of {query}?",
+            f"How does {query} work?",
+            f"What are the latest developments in {query}?",
+            f"What are common applications of {query}?",
+            f"How is {query} used today?"
+        ]
+
+        # Add base queries until we have at least 3
+        for bq in base_queries:
+            if len(topics) >= 3:
+                break
+            if not any(bq.lower() in t.lower() for t in topics):
+                topics.append(bq)
+
+    return topics[:3] # Return top 3 topics
 
 def ensure_citations(text, search_results):
     """Ensure citations are properly added to the text"""
@@ -234,90 +335,53 @@ def ensure_citations(text, search_results):
 
     return text
 
-def generate_related_topics(model, tokenizer, query, answer):
-    """Generate related topics based on the AI model"""
-    try:
-        # Craft a prompt to generate related topics
-        related_prompt = f"""Based on the original search query "{query}" and the information in this answer:
-        "{answer[:300]}...", generate 3 related topics or questions that someone might want to explore next.
-        Each should be specific and directly related to the query but explore a different aspect.
-        Format as a simple list with 3 items only."""
-
-        # Use the model to generate topics
-        related_text = generate_response(model, tokenizer, related_prompt, max_new_tokens=200)
-
-        # Parse the generated text into individual topics
-        lines = related_text.split('\n')
-        topics = []
-
-        for line in lines:
-            # Clean up line by removing numbers, bullet points, etc.
-            clean_line = re.sub(r'^[\d\-\*\•\.\s]+', '', line.strip())
-            if clean_line and len(clean_line) > 5:
-                topics.append(clean_line)
-
-        # Ensure we have at least 3 topics
-        if len(topics) < 3:
-            # Add generic but relevant topics based on the query
-            base_topics = [
-                f"History of {query}",
-                f"Latest developments in {query}",
-                f"How does {query} work?",
-                f"Applications of {query}",
-                f"Future of {query}"
-            ]
-
-            # Add topics until we have at least 3
-            for topic in base_topics:
-                if len(topics) >= 3:
-                    break
-                if topic not in topics:
-                    topics.append(topic)
-
-        return topics[:3] # Return top 3 topics
-
-    except Exception as e:
-        print(f"Error generating related topics: {e}")
-        # Return generic topics as fallback
-        return [
-            f"More about {query}",
-            f"Latest developments in {query}",
-            f"Applications of {query}"
-        ]
-
 def process_query(query):
-    """Main function to process a query with
+    """Main function to process a query with robust response generation"""
     try:
         # Step 1: Search the web for real results
         search_results = search_web(query, max_results=5)
 
-        # Step 2: Create context from search results
-        context = f"Query: {query}\n\
+        # Step 2: Create context from search results - shorter and more focused
+        context = f"Query: {query}\n\n"
+        context += "Search Results Summary:\n\n"
 
        for i, result in enumerate(search_results, 1):
-            context
-            context += f"
-            context += f"
-            context += f"Content: {result['snippet']}\n\n"
+            # Use shorter context to avoid token limits
+            context += f"Source {i}: {result['title']}\n"
+            context += f"Content: {result['snippet'][:150]}\n\n"
 
-        # Step 3: Create prompt for the AI model
-        prompt = f"""
+        # Step 3: Create a simpler prompt for the AI model
+        prompt = f"""Answer this question based on the search results: {query}
 
 {context}
 
-
-
-
-
+Provide a clear answer using information from these sources. Include citations like [1], [2] to reference sources."""
+
+        # Step 4: Generate answer using the improved generation function
+        answer = generate_response(prompt, max_new_tokens=384)
 
-        # Step
-        answer
+        # Step 5: Ensure we have some answer content
+        if not answer or len(answer.strip()) < 30:
+            print("Fallback to generic response")
+            answer = f"Based on the search results for '{query}', I found relevant information in the sources listed below. They provide details about this topic that you may find useful."
 
-        # Step
+        # Step 6: Ensure citations
         answer = ensure_citations(answer, search_results)
 
-        # Step
-
+        # Step 7: Generate related topics
+        # Use a simpler approach to get related topics since this might be failing too
+        try:
+            related_prompt = f"Generate 3 questions related to: {query}"
+            related_raw = generate_response(related_prompt, max_new_tokens=150)
+            related_topics = parse_related_topics(related_raw, query)
+        except Exception as e:
+            print(f"Error generating related topics: {e}")
+            # Fallback topics
+            related_topics = [
+                f"What is the history of {query}?",
+                f"How does {query} work?",
+                f"What are applications of {query}?"
+            ]
 
         # Return the complete result
         return {
@@ -330,9 +394,9 @@ Format your answer in clear paragraphs with bullet points where appropriate."""
         print(f"Error in process_query: {e}")
         # Return a minimal result that won't break the UI
         return {
-            "answer": f"I
-            "sources": search_web(query, max_results=2),
-            "related_topics": [f"
+            "answer": f"I found information about '{query}' in the sources below. They provide details about this topic that may be helpful.",
+            "sources": search_results if 'search_results' in locals() else search_web(query, max_results=2),
+            "related_topics": [f"What is {query}?", f"History of {query}", f"How to use {query}"]
         }
 
 def format_sources(sources):
@@ -455,9 +519,11 @@ def format_related(topics):
     observer.observe(document.body, { childList: true, subtree: true });
 
     // jQuery-like helper function
-    Element.prototype.contains
-
-
+    if (!Element.prototype.contains) {
+        Element.prototype.contains = function(text) {
+            return this.innerText.includes(text);
+        };
+    }
     </script>
     """
 
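For reference, a minimal sketch (not part of the commit) of how the revised entry points could be exercised after this change. It assumes app.py is importable, that its model, tokenizer, and MODEL_ID globals load successfully, and that format_sources and format_related (defined in unchanged parts of the file) return display-ready strings; the query string is hypothetical.

# Hypothetical smoke test for the updated pipeline.
# generate_response now takes only (prompt, max_new_tokens); model and tokenizer
# are read from module globals, so no handles need to be passed in.
from app import process_query, format_sources, format_related

result = process_query("solar energy storage")

# The error path in process_query returns exactly these keys; the normal path
# is assumed to match, so the lookups below should not raise.
print(result["answer"])                          # answer text with [n] citations
print(format_sources(result["sources"]))         # sources gathered by search_web
print(format_related(result["related_topics"]))  # three related questions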