AiActivity commited on
Commit
a6a8ad8
·
verified ·
1 Parent(s): 6a5bbc8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -90
app.py CHANGED
@@ -154,30 +154,79 @@ def search_web(query, max_results=5):
154
 
155
  return results[:max_results]
156
 
157
- def generate_response(model, tokenizer, prompt, max_new_tokens=512):
158
- """Generate response using the AI model with proper error handling"""
 
 
 
 
 
159
  try:
160
  # For T5 models
161
  if "t5" in MODEL_ID.lower():
162
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(model.device)
 
 
 
 
 
 
 
 
 
163
 
164
  with torch.no_grad():
165
  outputs = model.generate(
166
  inputs.input_ids,
167
  max_new_tokens=max_new_tokens,
168
- temperature=0.7,
169
- do_sample=True
 
 
170
  )
171
 
172
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
173
  return response
174
 
175
  # For Phi and other models
176
  else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  if "phi" in MODEL_ID.lower():
178
- formatted_prompt = f"Instruct: {prompt}\nOutput:"
179
  else:
180
- formatted_prompt = prompt
181
 
182
  inputs = tokenizer(formatted_prompt, return_tensors="pt", truncation=True, max_length=512).to(model.device)
183
 
@@ -185,8 +234,9 @@ def generate_response(model, tokenizer, prompt, max_new_tokens=512):
185
  outputs = model.generate(
186
  inputs.input_ids,
187
  max_new_tokens=max_new_tokens,
188
- temperature=0.7,
189
- top_p=0.9,
 
190
  do_sample=True,
191
  pad_token_id=tokenizer.eos_token_id
192
  )
@@ -194,24 +244,75 @@ def generate_response(model, tokenizer, prompt, max_new_tokens=512):
194
  response = tokenizer.decode(outputs[0][inputs.input_ids.size(1):], skip_special_tokens=True).strip()
195
 
196
  # Check if response is empty or too short
197
- if not response or len(response) < 10:
198
- # Try again with different parameters
 
 
199
  outputs = model.generate(
200
  inputs.input_ids,
201
  max_new_tokens=max_new_tokens,
202
- num_beams=3, # Use beam search instead
203
  temperature=1.0,
204
  do_sample=False, # Deterministic generation
 
205
  pad_token_id=tokenizer.eos_token_id
206
  )
207
 
208
  response = tokenizer.decode(outputs[0][inputs.input_ids.size(1):], skip_special_tokens=True).strip()
209
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  return response
 
211
  except Exception as e:
212
- print(f"Error generating response: {e}")
213
- # Return a simple error message
214
- return "I encountered a technical issue while generating a response. Please try another query."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
  def ensure_citations(text, search_results):
217
  """Ensure citations are properly added to the text"""
@@ -234,90 +335,53 @@ def ensure_citations(text, search_results):
234
 
235
  return text
236
 
237
- def generate_related_topics(model, tokenizer, query, answer):
238
- """Generate related topics based on the AI model"""
239
- try:
240
- # Craft a prompt to generate related topics
241
- related_prompt = f"""Based on the original search query "{query}" and the information in this answer:
242
- "{answer[:300]}...", generate 3 related topics or questions that someone might want to explore next.
243
- Each should be specific and directly related to the query but explore a different aspect.
244
- Format as a simple list with 3 items only."""
245
-
246
- # Use the model to generate topics
247
- related_text = generate_response(model, tokenizer, related_prompt, max_new_tokens=200)
248
-
249
- # Parse the generated text into individual topics
250
- lines = related_text.split('\n')
251
- topics = []
252
-
253
- for line in lines:
254
- # Clean up line by removing numbers, bullet points, etc.
255
- clean_line = re.sub(r'^[\d\-\*\•\.\s]+', '', line.strip())
256
- if clean_line and len(clean_line) > 5:
257
- topics.append(clean_line)
258
-
259
- # Ensure we have at least 3 topics
260
- if len(topics) < 3:
261
- # Add generic but relevant topics based on the query
262
- base_topics = [
263
- f"History of {query}",
264
- f"Latest developments in {query}",
265
- f"How does {query} work?",
266
- f"Applications of {query}",
267
- f"Future of {query}"
268
- ]
269
-
270
- # Add topics until we have at least 3
271
- for topic in base_topics:
272
- if len(topics) >= 3:
273
- break
274
- if topic not in topics:
275
- topics.append(topic)
276
-
277
- return topics[:3] # Return top 3 topics
278
-
279
- except Exception as e:
280
- print(f"Error generating related topics: {e}")
281
- # Return generic topics as fallback
282
- return [
283
- f"More about {query}",
284
- f"Latest developments in {query}",
285
- f"Applications of {query}"
286
- ]
287
-
288
  def process_query(query):
289
- """Main function to process a query with real search and AI responses"""
290
  try:
291
  # Step 1: Search the web for real results
292
  search_results = search_web(query, max_results=5)
293
 
294
- # Step 2: Create context from search results
295
- context = f"Query: {query}\n\nSearch Results:\n\n"
 
296
 
297
  for i, result in enumerate(search_results, 1):
298
- context += f"Source {i}:\n"
299
- context += f"Title: {result['title']}\n"
300
- context += f"URL: {result['url']}\n"
301
- context += f"Content: {result['snippet']}\n\n"
302
 
303
- # Step 3: Create prompt for the AI model
304
- prompt = f"""You are a helpful AI assistant that provides accurate and comprehensive answers based on search results.
305
 
306
  {context}
307
 
308
- Based on these search results, please provide a detailed answer to the query: "{query}"
309
- Include citations like [1], [2], etc. to reference the sources.
310
- Be factual and accurate. If the search results don't contain enough information, acknowledge this limitation.
311
- Format your answer in clear paragraphs with bullet points where appropriate."""
312
 
313
- # Step 4: Generate answer using the AI model
314
- answer = generate_response(model, tokenizer, prompt, max_new_tokens=512)
 
 
315
 
316
- # Step 5: Ensure citations
317
  answer = ensure_citations(answer, search_results)
318
 
319
- # Step 6: Generate related topics using the AI model
320
- related_topics = generate_related_topics(model, tokenizer, query, answer)
 
 
 
 
 
 
 
 
 
 
 
 
321
 
322
  # Return the complete result
323
  return {
@@ -330,9 +394,9 @@ Format your answer in clear paragraphs with bullet points where appropriate."""
330
  print(f"Error in process_query: {e}")
331
  # Return a minimal result that won't break the UI
332
  return {
333
- "answer": f"I encountered an error while processing your query about '{query}'. Please try again or try a different search term.",
334
- "sources": search_web(query, max_results=2), # Try to get at least some sources
335
- "related_topics": [f"More about {query}", f"Different aspects of {query}", f"Applications of {query}"]
336
  }
337
 
338
  def format_sources(sources):
@@ -455,9 +519,11 @@ def format_related(topics):
455
  observer.observe(document.body, { childList: true, subtree: true });
456
 
457
  // jQuery-like helper function
458
- Element.prototype.contains = function(text) {
459
- return this.innerText.includes(text);
460
- };
 
 
461
  </script>
462
  """
463
 
 
154
 
155
  return results[:max_results]
156
 
157
+ def generate_response(prompt, max_new_tokens=256):
158
+ """Generate response using the AI model with robust fallbacks"""
159
+ # Check if model is loaded properly
160
+ if 'model' not in globals() or model is None:
161
+ print("Model not available for generation")
162
+ return "Based on the search results, I can provide information about this topic. Please check the sources for more detailed information."
163
+
164
  try:
165
  # For T5 models
166
  if "t5" in MODEL_ID.lower():
167
+ # Simplify prompt for T5
168
+ simple_prompt = prompt
169
+ if len(simple_prompt) > 512:
170
+ # Truncate to essential parts for T5
171
+ parts = prompt.split("\n\n")
172
+ query_part = next((p for p in parts if p.startswith("Query:")), "")
173
+ instruction_part = parts[-1] if parts else ""
174
+ simple_prompt = f"{query_part}\n\n{instruction_part}"
175
+
176
+ inputs = tokenizer(simple_prompt, return_tensors="pt", truncation=True, max_length=512).to(model.device)
177
 
178
  with torch.no_grad():
179
  outputs = model.generate(
180
  inputs.input_ids,
181
  max_new_tokens=max_new_tokens,
182
+ temperature=0.8,
183
+ do_sample=True,
184
+ top_k=50,
185
+ repetition_penalty=1.2
186
  )
187
 
188
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
189
+
190
+ # If response is too short, try again with different parameters
191
+ if len(response) < 50:
192
+ outputs = model.generate(
193
+ inputs.input_ids,
194
+ max_new_tokens=max_new_tokens,
195
+ num_beams=4,
196
+ temperature=1.0,
197
+ do_sample=False
198
+ )
199
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
200
+
201
  return response
202
 
203
  # For Phi and other models
204
  else:
205
+ # Extract just the query from the prompt for simpler generation
206
+ query = ""
207
+ search_results_text = ""
208
+
209
+ if "Query:" in prompt:
210
+ query_section = prompt.split("Query:")[1].split("\n")[0].strip()
211
+ query = query_section
212
+ elif "question:" in prompt.lower():
213
+ query_section = prompt.split("question:")[1].split("\n")[0].strip()
214
+ query = query_section
215
+ else:
216
+ # Try to extract from the beginning of the prompt
217
+ query = prompt.split("\n")[0].strip()
218
+
219
+ if "Search Results:" in prompt:
220
+ search_results_text = prompt.split("Search Results:")[1].split("Based on")[0].strip()
221
+
222
+ # Create a simpler prompt format for better results
223
+ simple_prompt = f"Answer this question based on these search results:\n\nQuestion: {query}\n\nSearch Results: {search_results_text[:500]}...\n\nAnswer:"
224
+
225
+ # Adjust format based on model
226
  if "phi" in MODEL_ID.lower():
227
+ formatted_prompt = f"Instruct: {simple_prompt}\nOutput:"
228
  else:
229
+ formatted_prompt = simple_prompt
230
 
231
  inputs = tokenizer(formatted_prompt, return_tensors="pt", truncation=True, max_length=512).to(model.device)
232
 
 
234
  outputs = model.generate(
235
  inputs.input_ids,
236
  max_new_tokens=max_new_tokens,
237
+ temperature=0.85,
238
+ top_p=0.92,
239
+ top_k=50,
240
  do_sample=True,
241
  pad_token_id=tokenizer.eos_token_id
242
  )
 
244
  response = tokenizer.decode(outputs[0][inputs.input_ids.size(1):], skip_special_tokens=True).strip()
245
 
246
  # Check if response is empty or too short
247
+ if not response or len(response) < 20:
248
+ print("First generation attempt failed, trying alternative method")
249
+
250
+ # Try with different parameters
251
  outputs = model.generate(
252
  inputs.input_ids,
253
  max_new_tokens=max_new_tokens,
254
+ num_beams=3, # Use beam search
255
  temperature=1.0,
256
  do_sample=False, # Deterministic generation
257
+ repetition_penalty=1.2,
258
  pad_token_id=tokenizer.eos_token_id
259
  )
260
 
261
  response = tokenizer.decode(outputs[0][inputs.input_ids.size(1):], skip_special_tokens=True).strip()
262
 
263
+ # If still no good response, use a minimal reliable response
264
+ if not response or len(response) < 20:
265
+ print("Second generation attempt failed, using fallback response")
266
+
267
+ # Create a simple response that's guaranteed to work
268
+ if query:
269
+ base_response = f"Based on the search results, I can provide information about {query}. "
270
+ base_response += "The sources contain relevant details about this topic. "
271
+ base_response += "You can refer to them for more in-depth information."
272
+ return base_response
273
+ else:
274
+ return "Based on the search results, I can provide information related to your query. Please check the sources for more details."
275
+
276
  return response
277
+
278
  except Exception as e:
279
+ print(f"Error in generate_response: {e}")
280
+ # Return a guaranteed fallback response
281
+ return "Based on the search results, I found information related to your query. The sources listed below contain more detailed information about this topic."
282
+
283
+ def parse_related_topics(text, query):
284
+ """Extract related topics from generated text with better fallbacks"""
285
+ topics = []
286
+
287
+ # Parse lines and clean them up
288
+ lines = text.split('\n')
289
+ for line in lines:
290
+ # Clean up line from numbers and symbols
291
+ clean_line = re.sub(r'^[\d\-\*\•\.\s]+', '', line.strip())
292
+ if clean_line and len(clean_line) > 10:
293
+ # Make sure it ends with a question mark if it seems like a question
294
+ if any(q in clean_line.lower() for q in ['what', 'how', 'why', 'when', 'where', 'who']) and not clean_line.endswith('?'):
295
+ clean_line += '?'
296
+ topics.append(clean_line)
297
+
298
+ # If we don't have enough topics, generate some based on the query
299
+ if len(topics) < 3:
300
+ base_queries = [
301
+ f"What is the history of {query}?",
302
+ f"How does {query} work?",
303
+ f"What are the latest developments in {query}?",
304
+ f"What are common applications of {query}?",
305
+ f"How is {query} used today?"
306
+ ]
307
+
308
+ # Add base queries until we have at least 3
309
+ for bq in base_queries:
310
+ if len(topics) >= 3:
311
+ break
312
+ if not any(bq.lower() in t.lower() for t in topics):
313
+ topics.append(bq)
314
+
315
+ return topics[:3] # Return top 3 topics
316
 
317
  def ensure_citations(text, search_results):
318
  """Ensure citations are properly added to the text"""
 
335
 
336
  return text
337
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  def process_query(query):
339
+ """Main function to process a query with robust response generation"""
340
  try:
341
  # Step 1: Search the web for real results
342
  search_results = search_web(query, max_results=5)
343
 
344
+ # Step 2: Create context from search results - shorter and more focused
345
+ context = f"Query: {query}\n\n"
346
+ context += "Search Results Summary:\n\n"
347
 
348
  for i, result in enumerate(search_results, 1):
349
+ # Use shorter context to avoid token limits
350
+ context += f"Source {i}: {result['title']}\n"
351
+ context += f"Content: {result['snippet'][:150]}\n\n"
 
352
 
353
+ # Step 3: Create a simpler prompt for the AI model
354
+ prompt = f"""Answer this question based on the search results: {query}
355
 
356
  {context}
357
 
358
+ Provide a clear answer using information from these sources. Include citations like [1], [2] to reference sources."""
359
+
360
+ # Step 4: Generate answer using the improved generation function
361
+ answer = generate_response(prompt, max_new_tokens=384)
362
 
363
+ # Step 5: Ensure we have some answer content
364
+ if not answer or len(answer.strip()) < 30:
365
+ print("Fallback to generic response")
366
+ answer = f"Based on the search results for '{query}', I found relevant information in the sources listed below. They provide details about this topic that you may find useful."
367
 
368
+ # Step 6: Ensure citations
369
  answer = ensure_citations(answer, search_results)
370
 
371
+ # Step 7: Generate related topics
372
+ # Use a simpler approach to get related topics since this might be failing too
373
+ try:
374
+ related_prompt = f"Generate 3 questions related to: {query}"
375
+ related_raw = generate_response(related_prompt, max_new_tokens=150)
376
+ related_topics = parse_related_topics(related_raw, query)
377
+ except Exception as e:
378
+ print(f"Error generating related topics: {e}")
379
+ # Fallback topics
380
+ related_topics = [
381
+ f"What is the history of {query}?",
382
+ f"How does {query} work?",
383
+ f"What are applications of {query}?"
384
+ ]
385
 
386
  # Return the complete result
387
  return {
 
394
  print(f"Error in process_query: {e}")
395
  # Return a minimal result that won't break the UI
396
  return {
397
+ "answer": f"I found information about '{query}' in the sources below. They provide details about this topic that may be helpful.",
398
+ "sources": search_results if 'search_results' in locals() else search_web(query, max_results=2),
399
+ "related_topics": [f"What is {query}?", f"History of {query}", f"How to use {query}"]
400
  }
401
 
402
  def format_sources(sources):
 
519
  observer.observe(document.body, { childList: true, subtree: true });
520
 
521
  // jQuery-like helper function
522
+ if (!Element.prototype.contains) {
523
+ Element.prototype.contains = function(text) {
524
+ return this.innerText.includes(text);
525
+ };
526
+ }
527
  </script>
528
  """
529