""" | |
Enhanced Web Research Tool for GAIA Agent | |
Integrates with Exa API for advanced web search capabilities | |
""" | |
import os | |
import logging | |
import asyncio | |
from typing import Dict, List, Any, Optional, Union | |
from dataclasses import dataclass | |
from datetime import datetime, timedelta | |
import json | |
import re | |
try: | |
from exa_py import Exa | |
EXA_AVAILABLE = True | |
except ImportError: | |
EXA_AVAILABLE = False | |
try: | |
import requests | |
from bs4 import BeautifulSoup | |
WEB_SCRAPING_AVAILABLE = True | |
except ImportError: | |
WEB_SCRAPING_AVAILABLE = False | |
logger = logging.getLogger(__name__) | |

@dataclass
class SearchResult:
    """Structured search result with metadata."""
    title: str
    url: str
    content: str
    score: float
    source: str
    published_date: Optional[str] = None
    author: Optional[str] = None
    domain: str = ""

    def __post_init__(self):
        if self.url and not self.domain:
            try:
                from urllib.parse import urlparse
                self.domain = urlparse(self.url).netloc
            except Exception:
                self.domain = "unknown"

@dataclass
class SearchQuery:
    """Structured search query with parameters."""
    query: str
    query_type: str = "general"  # general, factual, biographical, historical, technical
    time_range: Optional[str] = None  # recent, year, month, week
    num_results: int = 10
    include_domains: Optional[List[str]] = None
    exclude_domains: Optional[List[str]] = None
    require_date: bool = False
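
# Illustrative only: hypothetical SearchQuery configurations (the values below
# are examples for readability, not defaults defined by this module).
#
#   SearchQuery(query="Eiffel Tower completion date", query_type="factual", num_results=5)
#   SearchQuery(query="transformer architecture survey", query_type="technical",
#               time_range="year", include_domains=["arxiv.org"])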

class EnhancedWebSearchTool:
    """
    Enhanced web search tool with multiple search strategies and result ranking.

    Features:
    - Exa API integration for semantic search
    - Multi-source search aggregation
    - Result ranking and relevance scoring
    - Fallback search strategies
    - Content extraction and summarization
    """

    def __init__(self, exa_api_key: Optional[str] = None):
        """Initialize the enhanced web search tool."""
        self.exa_api_key = exa_api_key or os.getenv("EXA_API_KEY")
        self.exa_client = None

        if self.exa_api_key and EXA_AVAILABLE:
            try:
                self.exa_client = Exa(api_key=self.exa_api_key)
                logger.info("✅ Exa API client initialized successfully")
            except Exception as e:
                logger.warning(f"⚠️ Failed to initialize Exa client: {e}")
        else:
            logger.warning("⚠️ Exa API not available - check API key and dependencies")

        # Initialize fallback search capabilities
        self.fallback_available = WEB_SCRAPING_AVAILABLE

        # Search result cache for efficiency
        self._cache = {}
        self._cache_ttl = 3600  # 1 hour cache

    def search(self, query: Union[str, SearchQuery], **kwargs) -> List[SearchResult]:
        """
        Perform enhanced web search with multiple strategies.

        Args:
            query: Search query string or SearchQuery object
            **kwargs: Additional search parameters

        Returns:
            List of SearchResult objects ranked by relevance
        """
        # Convert string query to SearchQuery object
        if isinstance(query, str):
            search_query = SearchQuery(
                query=query,
                query_type=kwargs.get('query_type', 'general'),
                time_range=kwargs.get('time_range'),
                num_results=kwargs.get('num_results', 10),
                include_domains=kwargs.get('include_domains'),
                exclude_domains=kwargs.get('exclude_domains'),
                require_date=kwargs.get('require_date', False)
            )
        else:
            search_query = query

        logger.info(f"🔍 Searching: {search_query.query}")

        # Check cache first
        cache_key = self._get_cache_key(search_query)
        if cache_key in self._cache:
            cache_entry = self._cache[cache_key]
            if datetime.now() - cache_entry['timestamp'] < timedelta(seconds=self._cache_ttl):
                logger.info("📋 Returning cached results")
                return cache_entry['results']

        results = []

        # Primary search: Exa API
        if self.exa_client:
            try:
                exa_results = self._search_with_exa(search_query)
                results.extend(exa_results)
                logger.info(f"✅ Exa search returned {len(exa_results)} results")
            except Exception as e:
                logger.warning(f"⚠️ Exa search failed: {e}")

        # Fallback search strategies
        if len(results) < search_query.num_results // 2:
            try:
                fallback_results = self._fallback_search(search_query)
                results.extend(fallback_results)
                logger.info(f"✅ Fallback search returned {len(fallback_results)} results")
            except Exception as e:
                logger.warning(f"⚠️ Fallback search failed: {e}")

        # Rank and filter results
        ranked_results = self._rank_results(results, search_query)

        # Cache results
        self._cache[cache_key] = {
            'results': ranked_results,
            'timestamp': datetime.now()
        }

        logger.info(f"🎯 Returning {len(ranked_results)} ranked results")
        return ranked_results

    def _search_with_exa(self, search_query: SearchQuery) -> List[SearchResult]:
        """Search using Exa API with advanced parameters."""
        if not self.exa_client:
            return []

        try:
            # Configure Exa search parameters
            search_params = {
                'query': search_query.query,
                'num_results': min(search_query.num_results, 20),
                'include_domains': search_query.include_domains,
                'exclude_domains': search_query.exclude_domains,
                'use_autoprompt': True,  # Let Exa optimize the query
                'type': 'neural'  # Use neural search for better semantic matching
            }

            # Add time filtering if specified
            if search_query.time_range:
                if search_query.time_range in ('recent', 'month'):
                    search_params['start_published_date'] = (datetime.now() - timedelta(days=30)).isoformat()
                elif search_query.time_range == 'year':
                    search_params['start_published_date'] = (datetime.now() - timedelta(days=365)).isoformat()
                elif search_query.time_range == 'week':
                    search_params['start_published_date'] = (datetime.now() - timedelta(days=7)).isoformat()

            # Perform search
            response = self.exa_client.search_and_contents(**search_params)

            results = []
            for item in response.results:
                try:
                    result = SearchResult(
                        title=item.title or "No title",
                        url=item.url,
                        content=item.text or "",
                        score=getattr(item, 'score', None) or 0.5,  # default when Exa returns no score
                        source="exa",
                        published_date=getattr(item, 'published_date', None),
                        author=getattr(item, 'author', None)
                    )
                    results.append(result)
                except Exception as e:
                    logger.warning(f"⚠️ Error processing Exa result: {e}")
                    continue

            return results

        except Exception as e:
            logger.error(f"❌ Exa search error: {e}")
            return []

    def _fallback_search(self, search_query: SearchQuery) -> List[SearchResult]:
        """Fallback search using DuckDuckGo or other methods."""
        if not WEB_SCRAPING_AVAILABLE:
            return []

        try:
            # Use DuckDuckGo search as fallback
            from duckduckgo_search import DDGS

            results = []
            with DDGS() as ddgs:
                search_results = ddgs.text(
                    search_query.query,
                    max_results=min(search_query.num_results, 10)
                )

                for item in search_results:
                    try:
                        result = SearchResult(
                            title=item.get('title', 'No title'),
                            url=item.get('href', ''),
                            content=item.get('body', ''),
                            score=0.3,  # Lower score for fallback results
                            source="duckduckgo"
                        )
                        results.append(result)
                    except Exception as e:
                        logger.warning(f"⚠️ Error processing DDG result: {e}")
                        continue

            return results

        except Exception as e:
            logger.warning(f"⚠️ Fallback search error: {e}")
            return []

    def _rank_results(self, results: List[SearchResult], search_query: SearchQuery) -> List[SearchResult]:
        """Rank search results by relevance and quality."""
        if not results:
            return []

        # Calculate relevance scores
        for result in results:
            relevance_score = self._calculate_relevance(result, search_query)
            quality_score = self._calculate_quality(result)

            # Combine scores (weighted average)
            result.score = (relevance_score * 0.7) + (quality_score * 0.3)

        # Sort by score (descending)
        ranked_results = sorted(results, key=lambda x: x.score, reverse=True)

        # Remove duplicates based on URL
        seen_urls = set()
        unique_results = []
        for result in ranked_results:
            if result.url not in seen_urls:
                seen_urls.add(result.url)
                unique_results.append(result)

        # Return top results
        return unique_results[:search_query.num_results]

    def _calculate_relevance(self, result: SearchResult, search_query: SearchQuery) -> float:
        """Calculate relevance score based on query matching."""
        query_terms = search_query.query.lower().split()
        title_lower = result.title.lower()
        content_lower = result.content.lower()

        # Count term matches in title (higher weight)
        title_matches = sum(1 for term in query_terms if term in title_lower)
        title_score = title_matches / len(query_terms) if query_terms else 0

        # Count term matches in content
        content_matches = sum(1 for term in query_terms if term in content_lower)
        content_score = content_matches / len(query_terms) if query_terms else 0

        # Combine scores
        relevance = (title_score * 0.6) + (content_score * 0.4)

        # Boost for exact phrase matches
        if search_query.query.lower() in title_lower:
            relevance += 0.3
        elif search_query.query.lower() in content_lower:
            relevance += 0.2

        return min(relevance, 1.0)

    def _calculate_quality(self, result: SearchResult) -> float:
        """Calculate quality score based on source and content characteristics."""
        quality = 0.5  # Base score

        # Domain reputation boost
        trusted_domains = [
            'wikipedia.org', 'britannica.com', 'reuters.com', 'bbc.com',
            'cnn.com', 'nytimes.com', 'washingtonpost.com', 'theguardian.com',
            'nature.com', 'science.org', 'arxiv.org', 'pubmed.ncbi.nlm.nih.gov'
        ]
        if any(domain in result.domain for domain in trusted_domains):
            quality += 0.3

        # Content length boost (longer content often more informative);
        # check the larger threshold first so the bigger boost is reachable
        if len(result.content) > 1000:
            quality += 0.2
        elif len(result.content) > 500:
            quality += 0.1

        # Published date boost (recent content)
        if result.published_date:
            try:
                pub_date = datetime.fromisoformat(result.published_date.replace('Z', '+00:00'))
                days_old = (datetime.now() - pub_date.replace(tzinfo=None)).days
                if days_old < 30:
                    quality += 0.1
                elif days_old < 365:
                    quality += 0.05
            except (ValueError, TypeError):
                pass

        # Source boost
        if result.source == "exa":
            quality += 0.1

        return min(quality, 1.0)
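
    # Worked example (illustrative numbers only): for a result with relevance 0.8
    # and quality 0.6, _rank_results combines them as 0.8 * 0.7 + 0.6 * 0.3 = 0.74,
    # so relevance dominates the ordering while quality mostly breaks near-ties.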

    def _get_cache_key(self, search_query: SearchQuery) -> str:
        """Generate cache key for search query."""
        key_data = {
            'query': search_query.query,
            'type': search_query.query_type,
            'time_range': search_query.time_range,
            'num_results': search_query.num_results
        }
        return str(hash(json.dumps(key_data, sort_keys=True)))

    def extract_content(self, url: str) -> Optional[str]:
        """Extract clean content from a URL."""
        if not WEB_SCRAPING_AVAILABLE:
            return None

        try:
            headers = {
                'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
            }
            response = requests.get(url, headers=headers, timeout=10)
            response.raise_for_status()

            soup = BeautifulSoup(response.content, 'html.parser')

            # Remove script and style elements
            for script in soup(["script", "style"]):
                script.decompose()

            # Get text content
            text = soup.get_text()

            # Clean up text
            lines = (line.strip() for line in text.splitlines())
            chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
            text = ' '.join(chunk for chunk in chunks if chunk)

            return text[:5000]  # Limit content length

        except Exception as e:
            logger.warning(f"⚠️ Content extraction failed for {url}: {e}")
            return None

    def search_for_factual_answer(self, question: str) -> Optional[str]:
        """
        Search for a specific factual answer to a question.

        Args:
            question: The factual question to answer

        Returns:
            The most likely answer or None if not found
        """
        # Create targeted search query
        search_query = SearchQuery(
            query=question,
            query_type="factual",
            num_results=5,
            require_date=False
        )

        results = self.search(search_query)
        if not results:
            return None

        # Extract potential answers from top results
        answers = []
        for result in results[:3]:  # Check top 3 results
            content = result.content
            if content:
                # Look for direct answers in the content
                answer = self._extract_answer_from_content(content, question)
                if answer:
                    answers.append(answer)

        # Return the first answer found (from the highest-ranked result)
        if answers:
            return answers[0]

        return None

    def _extract_answer_from_content(self, content: str, question: str) -> Optional[str]:
        """Extract a direct answer from content based on the question."""
        # This is a simplified answer extraction;
        # a production system would use more sophisticated NLP
        sentences = content.split('.')
        question_lower = question.lower()

        # Look for sentences that might contain the answer
        for sentence in sentences:
            sentence = sentence.strip()
            if 10 < len(sentence) < 200:
                # Check if sentence is relevant to the question
                if any(word in sentence.lower() for word in question_lower.split() if len(word) > 3):
                    return sentence

        return None

    def get_search_suggestions(self, partial_query: str) -> List[str]:
        """Get search suggestions for a partial query."""
        # This would typically use a search suggestion API;
        # for now, return some basic suggestions
        suggestions = [
            f"{partial_query} definition",
            f"{partial_query} facts",
            f"{partial_query} history",
            f"{partial_query} recent news",
            f"what is {partial_query}"
        ]
        return suggestions[:5]
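

# Minimal usage sketch (illustrative, not part of the module's documented API).
# Assumes an EXA_API_KEY is set in the environment; without it, the tool relies
# on the DuckDuckGo fallback when the duckduckgo_search package is installed.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    tool = EnhancedWebSearchTool()

    # Plain-string query; kwargs are converted into a SearchQuery internally
    for r in tool.search("history of the Eiffel Tower", num_results=3):
        print(f"[{r.score:.2f}] {r.title} ({r.domain})")

    # Targeted factual lookup returning a single candidate sentence (or None)
    answer = tool.search_for_factual_answer("When was the Eiffel Tower completed?")
    print("Answer:", answer)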