# Hugging Face Spaces status: Running
import json
import os
import re
import time
import urllib.parse
from html import escape

import gradio as gr
import requests
import torch
from bs4 import BeautifulSoup
from markdown import markdown
from transformers import AutoModelForCausalLM, AutoTokenizer
# Set environment variables
os.environ["TOKENIZERS_PARALLELISM"] = "false"

print("Loading model... Please wait...")

# Half precision only on a GPU: many float16 kernels are unavailable on CPU,
# which is where a free Space usually runs.
_TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

# Explicit defaults so generate_response() sees `model is None` (rather than a
# missing global) when every load attempt fails.
MODEL_ID = "microsoft/phi-2"
tokenizer = None
model = None

# Load the model with proper error handling: Phi-2 first, FLAN-T5-base second.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=_TORCH_DTYPE,
        device_map="auto",
        trust_remote_code=True,
    )
    print("Successfully loaded Phi-2 model")
except Exception as e:
    print(f"Error loading Phi-2: {e}")
    print("Trying fallback model...")
    # Discard any half-loaded Phi-2 pieces before trying the fallback.
    tokenizer = None
    model = None
    try:
        MODEL_ID = "google/flan-t5-base"
        from transformers import T5ForConditionalGeneration

        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        model = T5ForConditionalGeneration.from_pretrained(
            MODEL_ID,
            torch_dtype=_TORCH_DTYPE,
            device_map="auto",
        )
        print("Successfully loaded fallback model")
    except Exception as e:
        print(f"Error loading fallback model: {e}")
        print("Operating in reduced functionality mode")
        tokenizer = None
        model = None
_WIKI_API_URL = "https://en.wikipedia.org/w/api.php"
# Wikimedia's API etiquette asks clients to identify themselves with a
# descriptive User-Agent; generic library agents may be throttled or refused.
_WIKI_HEADERS = {"User-Agent": "AISearchSystem/1.0 (Hugging Face Spaces demo app)"}
# Browser-like User-Agent used when scraping Bing's HTML results page.
_BROWSER_HEADERS = {
    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
}


def _search_wikipedia(query, limit):
    """Return up to *limit* Wikipedia pages matching *query*, each with a 200-char intro snippet."""
    response = requests.get(
        _WIKI_API_URL,
        params={"action": "opensearch", "search": query, "limit": limit,
                "namespace": 0, "format": "json"},
        headers=_WIKI_HEADERS,
        timeout=5,
    )
    if response.status_code != 200:
        return []
    data = response.json()
    titles, urls = data[1], data[3]  # opensearch: [query, titles, descriptions, urls]
    hits = []
    for title, page_url in zip(titles, urls):
        try:
            page_response = requests.get(
                _WIKI_API_URL,
                params={"action": "query", "prop": "extracts", "exintro": 1,
                        "explaintext": 1, "titles": title, "format": "json"},
                headers=_WIKI_HEADERS,
                timeout=5,
            )
            if page_response.status_code != 200:
                continue
            pages = page_response.json()["query"]["pages"]
            page_id = next(iter(pages))
            if page_id == "-1":  # page does not exist
                continue
            extract = pages[page_id].get("extract", "")
            snippet = extract[:200] + "..." if len(extract) > 200 else extract
            hits.append({"title": f"Wikipedia - {title}", "url": page_url, "snippet": snippet})
        except Exception as e:
            # One failing page must not discard the hits gathered so far.
            print(f"Error extracting wiki data: {e}")
    return hits


def _search_serpapi(query, limit):
    """Return up to *limit* Google results via SerpAPI."""
    # NOTE(review): "demo" looks like a placeholder key; set SERPAPI_API_KEY
    # in the environment for real results.
    api_key = os.environ.get("SERPAPI_API_KEY", "demo")
    response = requests.get(
        "https://serpapi.com/search.json",
        params={"engine": "google", "q": query, "api_key": api_key},
        timeout=5,
    )
    if response.status_code != 200:
        return []
    organic = response.json().get("organic_results") or []
    return [
        {"title": item.get("title", ""), "url": item.get("link", ""), "snippet": item.get("snippet", "")}
        for item in organic[:limit]
    ]


def _search_bing(query, limit):
    """Scrape up to *limit* results from Bing's HTML results page (last resort)."""
    response = requests.get(
        "https://www.bing.com/search",
        params={"q": query},
        headers=_BROWSER_HEADERS,
        timeout=10,
    )
    if response.status_code != 200:
        return []
    soup = BeautifulSoup(response.text, "html.parser")
    hits = []
    for item in soup.find_all("li", class_="b_algo")[:limit]:
        heading = item.find("h2")
        link = heading.find("a") if heading else None
        if link is None or not link.get("href"):
            continue  # skip malformed entries instead of aborting the scrape
        caption = item.find("div", class_="b_caption")
        paragraph = caption.find("p") if caption else None
        hits.append({
            "title": heading.get_text(" ", strip=True),
            "url": link["href"],
            "snippet": paragraph.get_text(" ", strip=True) if paragraph else "",
        })
    return hits


def search_web(query, max_results=5):
    """Perform real web searches using multiple search endpoints.

    Providers are tried in order (Wikipedia, SerpAPI, Bing scraping) until
    *max_results* unique results have been collected. Each provider's failure
    is logged and skipped.

    Args:
        query: free-text search string.
        max_results: maximum number of results to return; <= 0 yields [].

    Returns:
        List of ``{"title", "url", "snippet"}`` dicts, de-duplicated by URL.
        If every provider fails, a single Google-search placeholder is
        returned so the UI always has something to show.
    """
    if max_results <= 0:
        return []

    providers = (
        ("Wikipedia search error", _search_wikipedia),
        ("SerpAPI error", _search_serpapi),
        ("Web scraping error", _search_bing),
    )
    results = []
    seen_urls = set()
    for error_label, provider in providers:
        remaining = max_results - len(results)
        if remaining <= 0:
            break
        try:
            hits = provider(query, remaining)
        except Exception as e:
            print(f"{error_label}: {e}")
            continue
        for hit in hits:
            if len(results) >= max_results:
                break
            if hit["url"] and hit["url"] in seen_urls:
                continue
            seen_urls.add(hit["url"])
            results.append(hit)

    if not results:
        results = [{
            'title': f"Search: {query}",
            'url': f"https://www.google.com/search?q={urllib.parse.quote(query)}",
            'snippet': "Search engine results for your query.",
        }]
    return results[:max_results]
def generate_response(prompt, max_new_tokens=256): | |
"""Generate response using the AI model with robust fallbacks""" | |
# Check if model is loaded properly | |
if 'model' not in globals() or model is None: | |
print("Model not available for generation") | |
response = f"Based on the search results for '{query}', I can provide the following information:\n\n" | |
# Extract key information from search results | |
for i, result in enumerate(search_results[:3], 1): | |
# Add a section for each source with actual content | |
title = result['title'].replace("Wikipedia - ", "") | |
content = result['snippet'] | |
response += f"**{title}**: {content} [{i}]\n\n" | |
# Add a conclusion | |
response += f"These sources provide information about {query} from different perspectives. For more detailed information, you can explore the full sources listed below." | |
return response | |
try: | |
# For T5 models | |
if "t5" in MODEL_ID.lower(): | |
# Simplify prompt for T5 | |
simple_prompt = prompt | |
if len(simple_prompt) > 512: | |
# Truncate to essential parts for T5 | |
parts = prompt.split("\n\n") | |
query_part = next((p for p in parts if p.startswith("Query:")), "") | |
instruction_part = parts[-1] if parts else "" | |
simple_prompt = f"{query_part}\n\n{instruction_part}" | |
inputs = tokenizer(simple_prompt, return_tensors="pt", truncation=True, max_length=512).to(model.device) | |
with torch.no_grad(): | |
outputs = model.generate( | |
inputs.input_ids, | |
max_new_tokens=max_new_tokens, | |
temperature=0.8, | |
do_sample=True, | |
top_k=50, | |
repetition_penalty=1.2 | |
) | |
response = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
# If response is too short, try again with different parameters | |
if len(response) < 50: | |
outputs = model.generate( | |
inputs.input_ids, | |
max_new_tokens=max_new_tokens, | |
num_beams=4, | |
temperature=1.0, | |
do_sample=False | |
) | |
response = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
return response | |
# For Phi and other models | |
else: | |
# Extract just the query from the prompt for simpler generation | |
query = "" | |
search_results_text = "" | |
if "Query:" in prompt: | |
query_section = prompt.split("Query:")[1].split("\n")[0].strip() | |
query = query_section | |
elif "question:" in prompt.lower(): | |
query_section = prompt.split("question:")[1].split("\n")[0].strip() | |
query = query_section | |
else: | |
# Try to extract from the beginning of the prompt | |
query = prompt.split("\n")[0].strip() | |
if "Search Results:" in prompt: | |
search_results_text = prompt.split("Search Results:")[1].split("Based on")[0].strip() | |
# Create a simpler prompt format for better results | |
simple_prompt = f"Answer this question based on these search results:\n\nQuestion: {query}\n\nSearch Results: {search_results_text[:500]}...\n\nAnswer:" | |
# Adjust format based on model | |
if "phi" in MODEL_ID.lower(): | |
formatted_prompt = f"Instruct: {simple_prompt}\nOutput:" | |
else: | |
formatted_prompt = simple_prompt | |
inputs = tokenizer(formatted_prompt, return_tensors="pt", truncation=True, max_length=512).to(model.device) | |
with torch.no_grad(): | |
outputs = model.generate( | |
inputs.input_ids, | |
max_new_tokens=max_new_tokens, | |
temperature=0.85, | |
top_p=0.92, | |
top_k=50, | |
do_sample=True, | |
pad_token_id=tokenizer.eos_token_id | |
) | |
response = tokenizer.decode(outputs[0][inputs.input_ids.size(1):], skip_special_tokens=True).strip() | |
# Check if response is empty or too short | |
if not response or len(response) < 20: | |
print("First generation attempt failed, trying alternative method") | |
# Try with different parameters | |
outputs = model.generate( | |
inputs.input_ids, | |
max_new_tokens=max_new_tokens, | |
num_beams=3, # Use beam search | |
temperature=1.0, | |
do_sample=False, # Deterministic generation | |
repetition_penalty=1.2, | |
pad_token_id=tokenizer.eos_token_id | |
) | |
response = tokenizer.decode(outputs[0][inputs.input_ids.size(1):], skip_special_tokens=True).strip() | |
# If still no good response, use a minimal reliable response | |
if not response or len(response) < 20: | |
print("Second generation attempt failed, using fallback response") | |
# Create a simple response that's guaranteed to work | |
if query: | |
base_response = f"Based on the search results, I can provide information about {query}. " | |
base_response += "The sources contain relevant details about this topic. " | |
base_response += "You can refer to them for more in-depth information." | |
return base_response | |
else: | |
return "Based on the search results, I can provide information related to your query. Please check the sources for more details." | |
return response | |
except Exception as e: | |
print(f"Error in generate_response: {e}") | |
# Return a guaranteed fallback response | |
return "Based on the search results, I found information related to your query. The sources listed below contain more detailed information about this topic." | |
# Leading list markers the model may emit: "1.", "2)", "3:", "-", "*", "•", "·".
# Digits are only stripped when followed by list punctuation, so "3D printing" survives.
_LIST_MARKER_RE = re.compile(r"^(?:(?:\d+[.):]|[-*\u2022\u00b7])\s*)+")
# Whole-word match so words like "shows", "whole" or "whatever" don't count as questions.
_QUESTION_WORD_RE = re.compile(r"\b(?:what|how|why|when|where|who)\b", re.IGNORECASE)


def parse_related_topics(text, query, max_topics=3):
    """Extract related topics from generated text with better fallbacks.

    Each line of *text* longer than 10 characters (after removing list
    markers) becomes a topic; lines containing a question word get a
    trailing "?". Case-insensitive duplicates are dropped. If fewer than
    *max_topics* remain, generic questions about *query* fill the gap.

    Args:
        text: raw model output (may be empty or None).
        query: the user's search query, used for the generic fallbacks.
        max_topics: number of topics to return (default 3).

    Returns:
        A list of at most *max_topics* topic strings.
    """
    topics = []
    seen = set()
    for raw_line in (text or "").split("\n"):
        clean_line = _LIST_MARKER_RE.sub("", raw_line.strip())
        if len(clean_line) <= 10:
            continue
        if _QUESTION_WORD_RE.search(clean_line) and not clean_line.endswith("?"):
            clean_line += "?"
        key = clean_line.lower()
        if key in seen:
            continue
        seen.add(key)
        topics.append(clean_line)

    if len(topics) < max_topics:
        base_queries = [
            f"What is the history of {query}?",
            f"How does {query} work?",
            f"What are the latest developments in {query}?",
            f"What are common applications of {query}?",
            f"How is {query} used today?",
        ]
        for bq in base_queries:
            if len(topics) >= max_topics:
                break
            if not any(bq.lower() in t.lower() for t in topics):
                topics.append(bq)
    return topics[:max_topics]
def ensure_citations(text, search_results):
    """Ensure citations are properly added to the text.

    If *text* already contains a ``[n]`` marker it is returned unchanged.
    Otherwise, every sentence fragment (split on ".") longer than 15
    characters from source *i*'s snippet that appears verbatim in *text* gets
    ``[i]`` appended after its first occurrence. If nothing matched and
    sources exist, a generic ``[1]`` is appended.

    Args:
        text: the generated answer.
        search_results: list of result dicts with a ``snippet`` key (may be empty).

    Returns:
        The cited text, or an apology message when *text* is shorter than 10
        non-blank characters.
    """
    if not text or len(text.strip()) < 10:
        return "I couldn't generate a proper response for this query. Please try a different search term."

    if re.search(r'\[\d+\]', text):
        return text

    for i, result in enumerate(search_results or [], 1):
        for phrase in (result.get('snippet') or '').split('.'):
            # Test and replace with the SAME stripped phrase; the fragments
            # after the first usually start with a space the answer lacks.
            phrase = phrase.strip()
            if len(phrase) > 15 and phrase in text:
                text = text.replace(phrase, f"{phrase} [{i}]", 1)

    # Only fall back to a generic citation when there is a source to cite.
    if search_results and not re.search(r'\[\d+\]', text):
        text += " [1]"
    return text
def process_query(query):
    """Run the whole pipeline for *query*: search, answer, cite, suggest follow-ups.

    Returns a dict with keys ``"answer"`` (str), ``"sources"`` (list of
    result dicts from search_web) and ``"related_topics"`` (list of str).
    Failures are printed and degrade to canned content instead of raising.
    """
    search_results = None  # stays None if search_web itself raises
    try:
        search_results = search_web(query, max_results=5)

        # Compact context: each source's title plus the first 150 chars of its snippet.
        pieces = [f"Query: {query}\n\n", "Search Results Summary:\n\n"]
        for number, hit in enumerate(search_results, start=1):
            pieces.append(f"Source {number}: {hit['title']}\n")
            pieces.append(f"Content: {hit['snippet'][:150]}\n\n")
        context = "".join(pieces)

        prompt = (
            f"Answer this question based on the search results: {query}\n"
            f"{context}\n"
            "Provide a clear answer using information from these sources. "
            "Include citations like [1], [2] to reference sources."
        )

        answer = generate_response(prompt, max_new_tokens=384)
        if len((answer or "").strip()) < 30:
            print("Fallback to generic response")
            answer = (
                f"Based on the search results for '{query}', I found relevant information "
                "in the sources listed below. They provide details about this topic "
                "that you may find useful."
            )
        answer = ensure_citations(answer, search_results)

        # Related questions come from a second, much shorter generation.
        try:
            related_raw = generate_response(f"Generate 3 questions related to: {query}", max_new_tokens=150)
            related_topics = parse_related_topics(related_raw, query)
        except Exception as e:
            print(f"Error generating related topics: {e}")
            related_topics = [
                f"What is the history of {query}?",
                f"How does {query} work?",
                f"What are applications of {query}?",
            ]

        return {"answer": answer, "sources": search_results, "related_topics": related_topics}
    except Exception as e:
        print(f"Error in process_query: {e}")
        if search_results is None:
            search_results = search_web(query, max_results=2)
        return {
            "answer": f"I found information about '{query}' in the sources below. They provide details about this topic that may be helpful.",
            "sources": search_results,
            "related_topics": [f"What is {query}?", f"History of {query}", f"How to use {query}"],
        }
def format_sources(sources):
    """Format sources for display as HTML cards.

    Titles, URLs and snippets come from scraped web pages, so every value is
    HTML-escaped. Only http(s) URLs become clickable; anything else (e.g. a
    ``javascript:`` URL) is neutralised to ``#``.

    Args:
        sources: iterable of dicts with ``title``, ``url`` and ``snippet`` keys.

    Returns:
        Concatenated HTML, or "" when *sources* is empty or None.
    """
    if not sources:
        return ""
    cards = []
    for source in sources:
        url = str(source.get('url') or '')
        href = url if url.lower().startswith(("http://", "https://")) else "#"
        safe_href = escape(href)
        safe_url = escape(url)
        safe_title = escape(str(source.get('title') or ''))
        safe_snippet = escape(str(source.get('snippet') or ''))
        cards.append(f"""
<div style="margin-bottom: 15px; padding: 15px; background-color: #FFFFFF;
border-radius: 12px; border-left: 4px solid #2563EB; box-shadow: 0 2px 6px rgba(0,0,0,0.08);">
<a href="{safe_href}" target="_blank" rel="noopener noreferrer" style="font-weight: 600;
color: #2563EB; text-decoration: none; font-size: 16px;">
{safe_title}
</a>
<div style="color: #64748B; font-size: 14px; margin-top: 6px;">{safe_url}</div>
<div style="margin-top: 10px; color: #374151; line-height: 1.5;">{safe_snippet}</div>
</div>
""")
    return "".join(cards)
def format_related(topics): | |
"""Format related topics for display with reliable click handlers""" | |
if not topics: | |
return "" | |
# Create HTML with unique IDs for each topic | |
html = "<div style='display: flex; flex-wrap: wrap; gap: 10px; margin-top: 15px;'>" | |
for i, topic in enumerate(topics): | |
# Each topic is a button with a unique ID | |
html += f""" | |
<div id="topic-{i}" style="background-color: #EFF6FF; padding: 10px 16px; border-radius: 100px; | |
color: #2563EB; font-size: 14px; font-weight: 500; cursor: pointer; display: inline-block; | |
transition: all 0.2s ease; border: 1px solid #DBEAFE; box-shadow: 0 1px 2px rgba(0,0,0,0.05);" | |
data-topic="{topic}" | |
onmouseover="this.style.backgroundColor='#DBEAFE'; this.style.boxShadow='0 2px 5px rgba(0,0,0,0.1)';" | |
onmouseout="this.style.backgroundColor='#EFF6FF'; this.style.boxShadow='0 1px 2px rgba(0,0,0,0.05)';"> | |
{topic} | |
</div> | |
""" | |
html += "</div>" | |
# Add JavaScript to handle topic clicks | |
html += """ | |
<script> | |
// Set up event listeners for topic clicks | |
function setupTopicClicks() { | |
// Find all topic elements | |
const topics = document.querySelectorAll('[id^="topic-"]'); | |
// Add click listeners to each topic | |
topics.forEach(topic => { | |
topic.addEventListener('click', function() { | |
// Get the topic text | |
const topicText = this.getAttribute('data-topic'); | |
console.log("Clicked topic:", topicText); | |
// Set input value to the topic text | |
const inputElement = document.getElementById('query-input'); | |
if (inputElement) { | |
inputElement.value = topicText; | |
// Try multiple methods to trigger the search | |
// Method 1: Click the search button | |
const searchButton = document.querySelector('button[data-testid="submit"]'); | |
if (searchButton) { | |
searchButton.click(); | |
return; | |
} | |
// Method 2: Try other button selectors | |
const altButton = document.querySelector('button[aria-label="Submit"]') || | |
document.querySelector('button:contains("Search")'); | |
if (altButton) { | |
altButton.click(); | |
return; | |
} | |
// Method 3: Find button by text content | |
const buttons = Array.from(document.querySelectorAll('button')); | |
const searchBtn = buttons.find(btn => | |
btn.textContent.includes('Search') || | |
btn.innerHTML.includes('Search') | |
); | |
if (searchBtn) { | |
searchBtn.click(); | |
return; | |
} | |
// Method 4: Trigger form submission directly | |
const form = inputElement.closest('form'); | |
if (form) { | |
const event = new Event('submit', { bubbles: true }); | |
form.dispatchEvent(event); | |
return; | |
} | |
console.log("Could not find a way to trigger search"); | |
} | |
}); | |
}); | |
} | |
// Run the setup function | |
setupTopicClicks(); | |
// Set up an observer to handle dynamically loaded topics | |
const observer = new MutationObserver(function(mutations) { | |
mutations.forEach(function(mutation) { | |
if (mutation.addedNodes.length) { | |
setupTopicClicks(); | |
} | |
}); | |
}); | |
// Start observing the document | |
observer.observe(document.body, { childList: true, subtree: true }); | |
// jQuery-like helper function | |
if (!Element.prototype.contains) { | |
Element.prototype.contains = function(text) { | |
return this.innerText.includes(text); | |
}; | |
} | |
</script> | |
""" | |
return html | |
def search_interface(query):
    """Main function for the Gradio interface with progress updates.

    A generator that yields ``(answer_html, sources_html, related_html)``
    tuples: first a loading message, then the final result. Because this is a
    generator, every outcome, including the empty-query message, must be
    *yielded*. A ``return value`` inside a generator is discarded, so the UI
    would show nothing.

    Args:
        query: text from the search box (None is treated as empty).
    """
    if not (query or "").strip():
        yield ("Please enter a search query.", "", "")
        return

    start_time = time.time()
    try:
        # Show loading message while processing
        yield ("Searching and generating response...", "", "")

        result = process_query(query)

        # The answer can echo scraped titles/snippets, so escape HTML before
        # rendering Markdown. With quote=False only &, < and > are escaped,
        # which leaves Markdown syntax such as ** and [1] intact.
        answer_html = markdown(escape(result["answer"], quote=False))
        sources_html = format_sources(result["sources"])
        related_html = format_related(result["related_topics"])

        processing_time = time.time() - start_time
        print(f"Query processed in {processing_time:.2f} seconds")

        yield (answer_html, sources_html, related_html)
    except Exception as e:
        print(f"Error in search_interface: {e}")
        yield (
            "I encountered an issue while processing your query. Please try again with a different search term.",
            "",
            "",
        )
# Create the Gradio interface with modern UI
# Stylesheet passed to gr.Blocks(css=css) when the interface is built.
# ".answer" styles the answer panel, which is created with elem_classes=["answer"];
# "#search-container" is the elem_id of the search row.
# NOTE(review): selectors such as button[data-testid="submit"] and
# "#search-container input" depend on the DOM Gradio generates. Confirm they
# still match the installed Gradio version (a Textbox may render a <textarea>).
css = """
/* Global styles */
body {
font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
background-color: #F9FAFB;
color: #1F2937;
line-height: 1.6;
}
/* Container styling */
.container {
max-width: 1200px;
margin: 0 auto;
padding: 0 20px;
}
/* Header styling */
.header {
text-align: center;
margin-bottom: 2rem;
}
/* Search box styling */
#search-container input {
border: 1px solid #E5E7EB;
border-radius: 12px;
padding: 12px 20px;
font-size: 16px;
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
transition: all 0.2s ease;
}
#search-container input:focus {
border-color: #2563EB;
box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.2);
outline: none;
}
/* Button styling */
button[data-testid="submit"] {
background-color: #2563EB !important;
color: white !important;
font-weight: 600 !important;
border-radius: 12px !important;
padding: 12px 24px !important;
border: none !important;
cursor: pointer !important;
transition: all 0.2s ease !important;
box-shadow: 0 2px 5px rgba(37, 99, 235, 0.3) !important;
}
button[data-testid="submit"]:hover {
background-color: #1D4ED8 !important;
box-shadow: 0 4px 8px rgba(37, 99, 235, 0.4) !important;
transform: translateY(-1px) !important;
}
/* Section headers */
h3 {
color: #2563EB;
font-weight: 600;
margin-top: 2rem;
margin-bottom: 1rem;
font-size: 1.25rem;
border-bottom: 2px solid #DBEAFE;
padding-bottom: 0.5rem;
}
/* Answer box styling */
.answer {
background-color: #FFFFFF;
padding: 24px;
border-radius: 12px;
box-shadow: 0 2px 6px rgba(0,0,0,0.05);
border: 1px solid #E5E7EB;
line-height: 1.7;
margin-bottom: 1.5rem;
color: #374151;
min-height: 100px;
}
.answer p {
margin-bottom: 1rem;
color: #1F2937;
}
.answer ul, .answer ol {
margin-left: 1.5rem;
margin-bottom: 1rem;
}
.answer strong, .answer b {
color: #111827;
font-weight: 600;
}
.answer a {
color: #2563EB;
text-decoration: none;
border-bottom: 1px solid currentColor;
}
/* Loading state */
.answer.loading {
display: flex;
align-items: center;
justify-content: center;
}
/* Footer styling */
footer {
margin-top: 2rem;
text-align: center;
color: #6B7280;
font-size: 0.875rem;
padding: 1rem 0;
}
/* Responsive styles */
@media (max-width: 768px) {
.answer {
padding: 16px;
}
button[data-testid="submit"] {
padding: 10px 16px !important;
}
}
"""
# Build the UI. Icons below replace mojibake ("π") from an earlier encoding
# mix-up; the original glyphs were not recoverable, so these are best guesses.
with gr.Blocks(css=css, theme=gr.themes.Default()) as demo:
    # Custom header with professional design
    gr.HTML("""
<div class="header">
    <h1 style="color: #2563EB; font-size: 2.2rem; font-weight: 700; margin-bottom: 0.5rem;">🔍 AI Search System</h1>
    <p style="color: #64748B; font-size: 1.1rem; max-width: 600px; margin: 0 auto;">
        Get comprehensive answers with real sources for any question.
    </p>
</div>
""")

    # Search row. The related-topic chips look up elem_id "query-input" and a
    # button whose label contains "Search", so keep both when editing.
    with gr.Row(elem_id="search-container"):
        query_input = gr.Textbox(
            label="Search Query",
            placeholder="What would you like to know?",
            elem_id="query-input",
            scale=4,
        )
        search_button = gr.Button("Search 🔍", variant="primary", scale=1)

    # Results: answer + related topics on the left, sources on the right.
    with gr.Row():
        with gr.Column(scale=2):
            gr.HTML("<h3>💡 Answer</h3>")
            answer_output = gr.HTML(elem_classes=["answer"])

            gr.HTML("<h3>🔗 Related Topics</h3>")
            related_output = gr.HTML()

        with gr.Column(scale=1):
            gr.HTML("<h3>📚 Sources</h3>")
            sources_output = gr.HTML()

    # Clicking the button and pressing Enter in the box run the same handler.
    search_button.click(
        fn=search_interface,
        inputs=[query_input],
        outputs=[answer_output, sources_output, related_output],
    )
    query_input.submit(
        fn=search_interface,
        inputs=[query_input],
        outputs=[answer_output, sources_output, related_output],
    )

    # Footer with attribution
    gr.HTML("""
<footer>
    <p>Built with Hugging Face Spaces</p>
</footer>
""")

# Queue incoming requests, holding at most 10 waiting events.
demo.queue(max_size=10)

if __name__ == "__main__":
    demo.launch()