Spaces:
Sleeping
Sleeping
import logging
import re
from urllib.parse import unquote

from smolagents import Tool
from smolagents.default_tools import DuckDuckGoSearchTool, WikipediaSearchTool
logger = logging.getLogger(__name__) | |
class SmartSearchTool(Tool):
    """Search the web and enrich the answer with Wikipedia content when possible.

    The tool runs a single-result web search. When that result links to a
    Wikipedia article, the article title is recovered from the URL and looked
    up through the Wikipedia tool; both results are then returned together.
    In every other case only the web search result is returned.
    """

    name = "smart_search"
    description = (
        "A smart search tool that first performs a web search and then, if a Wikipedia article is found,\n"
        "uses Wikipedia search for more reliable information."
    )
    inputs = {"query": {"type": "string", "description": "The search query to find information"}}
    output_type = "string"

    def __init__(self):
        """Create the web search and Wikipedia tools this tool delegates to."""
        super().__init__()
        # Only one web result is requested; it is scanned for a Wikipedia link.
        self.web_search_tool = DuckDuckGoSearchTool(max_results=1)
        self.wiki_tool = WikipediaSearchTool(
            user_agent="SmartSearchTool (smartsearch@example.com)",
            language="en",
            # "text" is used instead of the alternative "summary" content type.
            content_type="text",
            extract_format="WIKI",
        )

    @staticmethod
    def _extract_wiki_title(text: str) -> str:
        """Return the article title of the first Wikipedia URL in *text*.

        The title is cut at the first unbalanced ``)`` so that a Markdown link
        such as ``[x](https://en.wikipedia.org/wiki/Mercury_(planet))`` yields
        ``Mercury (planet)`` rather than a truncated ``Mercury_(planet``. Any
        URL fragment or query string is dropped, percent-escapes are decoded
        and underscores are turned into spaces.

        Args:
            text: Arbitrary text that may contain a Wikipedia article URL.

        Returns:
            The decoded article title, or an empty string when no usable
            title could be extracted.
        """
        # Case-insensitive to stay consistent with the lower-cased
        # "wikipedia.org" check performed by the caller.
        match = re.search(r'wikipedia\.org/wiki/([^\s<>\[\]"]+)', text, re.IGNORECASE)
        if match is None:
            return ""
        candidate = match.group(1)

        # Keep balanced parentheses (they are part of many article titles) but
        # stop at the first ")" that closes something outside the URL.
        depth = 0
        cut = len(candidate)
        for index, char in enumerate(candidate):
            if char == "(":
                depth += 1
            elif char == ")":
                if depth == 0:
                    cut = index
                    break
                depth -= 1
        candidate = candidate[:cut]

        # A raw "#" or "?" starts a fragment / query string, not title text
        # (inside a title they would appear percent-encoded).
        candidate = re.split(r"[#?]", candidate, maxsplit=1)[0]

        return unquote(candidate).replace("_", " ").strip()

    @staticmethod
    def _is_usable_wiki_result(wiki_result: str) -> bool:
        """Tell whether *wiki_result* looks like real article content.

        Args:
            wiki_result: The string returned by the Wikipedia tool.

        Returns:
            ``False`` for an empty result or a recognised failure message,
            ``True`` otherwise.
        """
        if not wiki_result:
            return False
        if "No Wikipedia page found" in wiki_result:
            return False
        # NOTE(review): WikipediaSearchTool appears to report internal errors
        # as a string starting with "Error fetching Wikipedia" instead of
        # raising -- confirm against the installed smolagents version.
        return not wiki_result.startswith("Error fetching Wikipedia")

    def forward(self, query: str) -> str:
        """Run the web search and append Wikipedia content when available.

        Args:
            query: The search query to find information about.

        Returns:
            The web search result, followed by the Wikipedia result when a
            Wikipedia article was linked and could be retrieved.
        """
        logger.info("Starting smart search for query: %s", query)

        # First perform a web search with a single result.
        web_result = self.web_search_tool.forward(query)
        logger.info("Web search result: %s...", web_result[:100])

        if "wikipedia.org" in web_result.lower():
            logger.info("Wikipedia link found in web search results")
            wiki_title = self._extract_wiki_title(web_result)
            if wiki_title:
                logger.info("Extracted Wikipedia title: %s", wiki_title)
                wiki_result = self.wiki_tool.forward(wiki_title)
                logger.info("Wikipedia search result: %s...", wiki_result[:100])
                if self._is_usable_wiki_result(wiki_result):
                    logger.info("Successfully retrieved Wikipedia content")
                    return f"Web search result:\n{web_result}\n\nWikipedia result:\n{wiki_result}"
                logger.warning("Wikipedia search failed or returned no results")
            else:
                logger.warning("Could not extract Wikipedia title from URL")

        # No Wikipedia link was found, or the Wikipedia lookup failed.
        logger.info("Returning web search result only")
        return f"Web search result:\n{web_result}"
def main(query: str) -> str:
    """Run a single smart search and echo its output to stdout.

    Intended as a manual test entry point: it configures INFO-level logging,
    builds a ``SmartSearchTool``, runs it on *query* and prints the result
    between two 80-character divider lines.

    Args:
        query: The search query to test.

    Returns:
        The search results produced by the tool.
    """
    logging.basicConfig(
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        level=logging.INFO,
    )

    search_output = SmartSearchTool().forward(query)

    divider = "-" * 80
    print("\nSearch Results:")
    print(divider)
    print(search_output)
    print(divider)

    return search_output
if __name__ == "__main__":
    import sys

    # Without any argument after the script name there is nothing to search
    # for, so show the usage hint; otherwise all arguments form one query.
    if len(sys.argv) <= 1:
        print("Usage: python tool.py <search query>")
        print("Example: python tool.py 'Mercedes Sosa discography'")
    else:
        query = " ".join(sys.argv[1:])
        main(query)