|
import logging |
|
import re |
|
from smolagents import Tool |
|
from smolagents.default_tools import DuckDuckGoSearchTool, VisitWebpageTool |
|
|
|
# Module-level logger; SmartSearchTool uses it to trace the query, the search
# results preview, each URL visited, and any page-visit errors.
logger = logging.getLogger(__name__)
|
|
|
|
|
class SmartSearchTool(Tool):
    """Search the web, then fetch the content of every URL in the results.

    Wraps smolagents' ``DuckDuckGoSearchTool`` (limited to one result) and
    ``VisitWebpageTool`` so a single call returns both the search snippet and
    the full markdown content of the linked page(s).
    """

    name = "smart_search"
    description = """A smart search tool that first performs a web search and then visits each URL to get its content."""
    inputs = {"query": {"type": "string", "description": "The search query to find information"}}
    output_type = "string"

    def __init__(self):
        # Local import keeps this edit self-contained; only needed for the
        # "no truncation" limit below.
        import sys

        super().__init__()
        self.web_search_tool = DuckDuckGoSearchTool(max_results=1)
        # A negative max_output_length does NOT mean "unlimited": smolagents'
        # truncation check is `len(content) <= max_length`, so -1 always
        # truncates and splices slices of the page together (returning the
        # page roughly twice). The largest int effectively disables truncation.
        self.visit_webpage_tool = VisitWebpageTool(max_output_length=sys.maxsize)

    def forward(self, query: str) -> str:
        """Run the web search and append the content of every linked page.

        Args:
            query: Free-text search query passed to the DuckDuckGo search tool.

        Returns:
            ``"Web search results:\\n<results>"`` followed by one
            ``Content from <url>:`` (or ``Error visiting <url>: ...``) section
            per distinct URL found in the results. When the results contain no
            URL, only the search results are returned.
        """
        logger.info("Starting smart search for query: %s", query)

        web_search_results = self.web_search_tool.forward(query)
        logger.info("Web search results: %s...", web_search_results[:100])

        urls = self._extract_urls(web_search_results)
        if not urls:
            logger.info("No URLs found in web search result")
            return f"Web search results:\n{web_search_results}"

        contents = []
        for url in urls:
            logger.info("Visiting URL: %s", url)
            try:
                content = self.visit_webpage_tool.forward(url)
            except Exception as e:
                # Best effort: one unreachable page must not sink the whole
                # search, so record the error in the output and move on.
                logger.warning("Error visiting %s: %s", url, e)
                contents.append(f"\nError visiting {url}: {e}")
            else:
                if content:
                    contents.append(f"\nContent from {url}:\n{content}")

        return f"Web search results:\n{web_search_results}\n" + "\n".join(contents)

    @staticmethod
    def _extract_urls(text: str) -> list[str]:
        """Return the distinct http(s) URLs in ``text``, in first-seen order.

        One level of balanced parentheses is allowed inside a URL (e.g.
        ``https://en.wikipedia.org/wiki/Cantora_(album)``), while the ``)``
        that closes a markdown link ``[title](url)`` is still excluded.
        Trailing sentence punctuation picked up from snippet text is
        stripped, and duplicates are dropped so no page is fetched twice.
        """
        # The two alternatives start with disjoint characters, so matching
        # stays linear (no catastrophic backtracking).
        found = re.findall(r"https?://(?:[^\s()]|\([^\s()]*\))+", text)
        cleaned = (url.rstrip(".,;:!?'\"") for url in found)
        return list(dict.fromkeys(cleaned))
|
|
|
|
|
def main(query: str) -> str:
    """Run ``SmartSearchTool`` once and echo its output to stdout.

    Turns on INFO-level root logging so the tool's progress messages are
    visible, runs the search, then prints the result framed by two
    80-character horizontal rules.

    Args:
        query: The search query to run.

    Returns:
        The string produced by ``SmartSearchTool.forward`` for ``query``.
    """
    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)

    search_output = SmartSearchTool().forward(query)

    rule = "-" * 80
    for line in ("\nSearch Results:", rule, search_output, rule):
        print(line)

    return search_output
|
|
|
|
|
if __name__ == "__main__":
    import sys

    # Everything after the script name is joined into a single query string.
    cli_args = sys.argv[1:]
    if cli_args:
        main(" ".join(cli_args))
    else:
        print(
            "Usage: python tool.py <search query>",
            "Example: python tool.py 'Mercedes Sosa discography'",
            sep="\n",
        )
|
|