"""Web search tooling: Serper-backed search, result filtering and page scraping.

Also contains crawl4ai-based helpers (``url_to_contents`` / ``url_to_fit_contents``)
that convert a page to markdown.
"""

import asyncio
import json
import os
import ssl
from typing import List, Optional, Union
from urllib.parse import urlsplit

import aiohttp
from agents import function_tool
from bs4 import BeautifulSoup
from dotenv import load_dotenv
from pydantic import BaseModel, Field

# from ..workers.baseclass import ResearchAgent, ResearchRunner
# from ..workers.utils.parse_output import create_type_parser

# The star import supplies AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode,
# DefaultMarkdownGenerator and PruningContentFilter used by the crawl4ai helpers below.
from crawl4ai import *  # noqa: F401,F403

load_dotenv()

CONTENT_LENGTH_LIMIT = 10000  # Trim scraped content to this length to avoid large context / token limit issues
SEARCH_PROVIDER = os.getenv("SEARCH_PROVIDER", "serper").lower()


# ------- DEFINE TYPES -------
# NOTE: these models intentionally have no class docstrings -- pydantic would copy them into
# model_json_schema(), which is embedded verbatim in FILTER_AGENT_INSTRUCTIONS.


# A scraped page: the search snippet data plus the extracted (or error) text.
class ScrapeResult(BaseModel):
    url: str = Field(description="The URL of the webpage")
    text: str = Field(description="The full text content of the webpage")
    title: str = Field(description="The title of the webpage")
    description: str = Field(description="A short description of the webpage")


# One organic search result as returned by SerperClient.search.
class WebpageSnippet(BaseModel):
    url: str = Field(description="The URL of the webpage")
    title: str = Field(description="The title of the webpage")
    description: Optional[str] = Field(description="A short description of the webpage")


# Output schema for the (currently disabled) LLM relevance filter.
class SearchResults(BaseModel):
    results_list: List[WebpageSnippet]


# ------- DEFINE TOOL -------

# Module-level singleton SerperClient, created lazily on the first web_search call.
_serper_client = None


@function_tool
async def web_search(query: str) -> Union[List[ScrapeResult], str]:
    """Perform a web search for a given query and get back the URLs along with their titles, descriptions and text contents.

    Args:
        query: The search query

    Returns:
        List of ScrapeResult objects which have the following fields:
            - url: The URL of the search result
            - title: The title of the search result
            - description: The description of the search result
            - text: The full text content of the search result
    """
    # NOTE(review): the docstring above is presumably turned into the tool description by the
    # agents SDK's function_tool, so it is kept model-facing; maintainer notes go in comments.
    global _serper_client

    if SEARCH_PROVIDER == "openai":
        # For the OpenAI search provider this function should not be called directly;
        # the WebSearchTool from the agents module is used instead.
        return "The web_search function is not used when SEARCH_PROVIDER is set to 'openai'. Please check your configuration."

    try:
        # Lazy initialization of the shared SerperClient (raises ValueError without an API key).
        if _serper_client is None:
            _serper_client = SerperClient()
        search_results = await _serper_client.search(
            query, filter_for_relevance=True, max_results=5
        )
        return await scrape_urls(search_results)
    except Exception as e:
        # Tool boundary: report the failure to the calling agent as text instead of raising.
        return f"Sorry, I encountered an error while searching: {str(e)}"


# ------- DEFINE AGENT FOR FILTERING SEARCH RESULTS BY RELEVANCE -------

FILTER_AGENT_INSTRUCTIONS = f"""
You are a search result filter. Your task is to analyze a list of SERP search results and determine which ones are relevant
to the original query based on the link, title and snippet. Return only the relevant results in the specified format.

- Remove any results that refer to entities that have similar names to the queried entity, but are not the same.
- E.g. if the query asks about a company "Acme Inc, acme.com", remove results with "acmesolutions.com" or "acme.net" in the link.

Only output JSON. Follow the JSON schema below. Do not output anything else. I will be parsing this with Pydantic so output valid JSON only:
{SearchResults.model_json_schema()}
"""

# selected_model = fast_model
#
# filter_agent = ResearchAgent(
#     name="SearchFilterAgent",
#     instructions=FILTER_AGENT_INSTRUCTIONS,
#     model=selected_model,
#     output_type=SearchResults if model_supports_structured_output(selected_model) else None,
#     output_parser=create_type_parser(SearchResults) if not model_supports_structured_output(selected_model) else None
# )


# ------- DEFINE UNDERLYING TOOL LOGIC -------

# Shared TLS settings for the outgoing HTTP requests made in this module.
# NOTE(review): certificate and hostname verification are disabled, and this context is also used
# for the Serper API request that carries SERPER_API_KEY, so the key could be exposed to a
# man-in-the-middle. Consider a verified context (ssl.create_default_context()) for that call.
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
ssl_context.set_ciphers("DEFAULT:@SECLEVEL=1")  # Allow older cipher suites so legacy sites can be fetched


class SerperClient:
    """A client for the Serper API to perform Google searches."""

    def __init__(self, api_key: Optional[str] = None):
        """Create a client.

        Args:
            api_key: Serper API key. Falls back to the SERPER_API_KEY environment variable.

        Raises:
            ValueError: If no API key is given and SERPER_API_KEY is unset or empty.
        """
        self.api_key = api_key or os.getenv("SERPER_API_KEY")
        if not self.api_key:
            raise ValueError(
                "No API key provided. Set SERPER_API_KEY environment variable."
            )
        self.url = "https://google.serper.dev/search"
        self.headers = {"X-API-KEY": self.api_key, "Content-Type": "application/json"}

    async def search(
        self, query: str, filter_for_relevance: bool = True, max_results: int = 5
    ) -> List[WebpageSnippet]:
        """Perform a Google search using the Serper API and return the top organic results.

        Args:
            query: The search query.
            filter_for_relevance: If True, pass results through ``_filter_results`` (drops
                PubMed/PMC links and unreachable URLs). If False, return the raw organic results.
            max_results: Maximum number of results to return.

        Returns:
            Up to ``max_results`` WebpageSnippet objects; an empty list when Serper returns no
            organic results.

        Raises:
            aiohttp.ClientResponseError: If the Serper API answers with an error status.
        """
        connector = aiohttp.TCPConnector(ssl=ssl_context)
        async with aiohttp.ClientSession(connector=connector) as session:
            async with session.post(
                self.url, headers=self.headers, json={"q": query, "autocorrect": False}
            ) as response:
                response.raise_for_status()
                results = await response.json()

        # `or ""` also covers keys that are present but null, which pydantic's `str` would reject.
        results_list = [
            WebpageSnippet(
                url=result.get("link") or "",
                title=result.get("title") or "",
                description=result.get("snippet") or "",
            )
            for result in results.get("organic", [])
        ]

        if not results_list:
            return []
        if not filter_for_relevance:
            return results_list[:max_results]
        return await self._filter_results(results_list, query, max_results=max_results)

    async def _filter_results(
        self, results: List[WebpageSnippet], query: str, max_results: int = 5
    ) -> List[WebpageSnippet]:
        """Drop PubMed/PMC results and unreachable URLs, then keep at most ``max_results``.

        ``query`` is only needed by the commented-out LLM relevance filter below.
        """
        # Get rid of PubMed / PMC source data.
        filtered_results = [
            res
            for res in results
            if "pmc.ncbi.nlm.nih.gov" not in res.url
            and "pubmed.ncbi.nlm.nih.gov" not in res.url
        ]

        # LLM-based relevance filtering; commented out because filter_agent and ResearchRunner are
        # commented out in this module as well.
        # serialized_results = [
        #     result.model_dump() if isinstance(result, WebpageSnippet) else result
        #     for result in filtered_results
        # ]
        #
        # user_prompt = f"""
        # Original search query: {query}
        #
        # Search results to analyze:
        # {json.dumps(serialized_results, indent=2)}
        #
        # Return {max_results} search results or less.
        # """
        #
        # try:
        #     result = await ResearchRunner.run(filter_agent, user_prompt)
        #     output = result.final_output_as(SearchResults)
        #     return output.results_list
        # except Exception as e:
        #     print("Error filtering urls:", str(e))
        #     return filtered_results[:max_results]

        reachable_results = await _filter_unreachable_urls(filtered_results)
        # Return the first `max_results`, or fewer if there are not enough reachable results.
        return reachable_results[:max_results]


async def _is_url_reachable(
    session: aiohttp.ClientSession, url: str, timeout_s: float = 5
) -> bool:
    """Return True if a GET on ``url`` answers HTTP 200 within ``timeout_s`` seconds."""
    try:
        async with session.get(url, timeout=aiohttp.ClientTimeout(total=timeout_s)) as response:
            return response.status == 200
    except Exception as e:
        print(f"Error accessing {url}: {str(e)}")
        return False  # False means the URL is not accessible


async def _filter_unreachable_urls(results: List[WebpageSnippet]) -> List[WebpageSnippet]:
    """Return the results whose URL is reachable, preserving their order."""
    # Use the same lenient TLS settings as scrape_urls, so pages the scraper could fetch are not
    # discarded here because of certificate or cipher problems.
    connector = aiohttp.TCPConnector(ssl=ssl_context)
    async with aiohttp.ClientSession(connector=connector) as session:
        reachable = await asyncio.gather(
            *(_is_url_reachable(session, res.url) for res in results)
        )
    return [res for res, can_access in zip(results, reachable) if can_access]


async def scrape_urls(items: List[WebpageSnippet]) -> List[ScrapeResult]:
    """Fetch text content from provided URLs.

    Args:
        items: List of WebpageSnippet items to extract content from. Items with an empty URL are skipped.

    Returns:
        List of ScrapeResult objects which have the following fields:
            - url: The URL of the search result
            - title: The title of the search result
            - description: The description of the search result
            - text: The full text content of the search result (or an "Error fetching content: ..." message)
    """
    connector = aiohttp.TCPConnector(ssl=ssl_context)
    async with aiohttp.ClientSession(connector=connector) as session:
        # Fetch all non-empty URLs concurrently.
        tasks = [fetch_and_process_url(session, item) for item in items if item.url]
        results = await asyncio.gather(*tasks, return_exceptions=True)

    # Unexpected exceptions are dropped; only successful ScrapeResult objects are returned.
    return [r for r in results if isinstance(r, ScrapeResult)]


async def fetch_and_process_url(
    session: aiohttp.ClientSession, item: WebpageSnippet
) -> ScrapeResult:
    """Fetch a single URL and convert its HTML to text.

    Failures (restricted extension, non-200 status, network or parsing errors) are reported
    in the returned ScrapeResult's ``text`` field instead of being raised.
    """
    # ScrapeResult.description is a required str while WebpageSnippet.description may be None;
    # normalising here keeps a missing snippet from failing ScrapeResult validation.
    description = item.description or ""

    def _result(text: str) -> ScrapeResult:
        return ScrapeResult(
            url=item.url, title=item.title, description=description, text=text
        )

    if not is_valid_url(item.url):
        return _result("Error fetching content: URL contains restricted file extension")

    try:
        async with session.get(item.url, timeout=aiohttp.ClientTimeout(total=8)) as response:
            if response.status != 200:
                return _result(f"Error fetching content: HTTP {response.status}")
            # errors="replace" keeps pages that contain a few undecodable bytes.
            content = await response.text(errors="replace")
        # Run html_to_text in a thread pool so HTML parsing does not block the event loop.
        text_content = await asyncio.get_running_loop().run_in_executor(
            None, html_to_text, content
        )
    except Exception as e:
        return _result(f"Error fetching content: {str(e)}")

    # Trim content to avoid exceeding the token limit.
    return _result(text_content[:CONTENT_LENGTH_LIMIT])


def html_to_text(html_content: str) -> str:
    """
    Strips out all of the unnecessary elements from the HTML context to prepare it for text extraction / LLM processing.

    Returns the non-empty text of every heading, paragraph, list item and blockquote, one per line.
    """
    # Parse the HTML using lxml for speed.
    soup = BeautifulSoup(html_content, "lxml")

    tags_to_extract = ("h1", "h2", "h3", "h4", "h5", "h6", "p", "li", "blockquote")

    # A " " separator keeps words from separate inline children apart; with the default ""
    # separator and strip=True, "<p>Hello <b>world</b></p>" would become "Helloworld".
    texts = (
        element.get_text(" ", strip=True) for element in soup.find_all(tags_to_extract)
    )
    # NOTE(review): nested matches (e.g. a <p> inside an <li>) are each emitted, so their text
    # can appear more than once in the output.
    return "\n".join(text for text in texts if text)


# File extensions of non-HTML resources that html_to_text cannot meaningfully process.
_RESTRICTED_EXTENSIONS = (
    ".pdf", ".doc", ".docx", ".xls", ".xlsx", ".ppt", ".pptx",
    ".zip", ".rar", ".7z", ".txt", ".js", ".json", ".xml", ".css",
    ".png", ".jpg", ".jpeg", ".gif", ".ico", ".svg", ".webp",
    ".mp3", ".mp4", ".avi", ".mov", ".wmv", ".flv", ".wma", ".wav",
    ".m4a", ".m4v", ".m4b", ".m4p", ".m4u",
)


def is_valid_url(url: str) -> bool:
    """Check that a URL does not point at a file with a restricted extension.

    Only the last segment of the URL path is inspected (case-insensitively), so host names such
    as ``www.docker.com`` or ``www.jstor.org`` and query strings are not mistaken for extensions.
    Malformed URLs that ``urlsplit`` rejects are treated as invalid.
    """
    try:
        path = urlsplit(url).path
    except ValueError:
        return False
    last_segment = path.rsplit("/", 1)[-1].lower()
    return not last_segment.endswith(_RESTRICTED_EXTENSIONS)


async def url_to_contents(url):
    """Crawl ``url`` with crawl4ai's default settings and return the result's markdown."""
    async with AsyncWebCrawler() as crawler:
        result = await crawler.arun(
            url=url,
        )
        return result.markdown


async def url_to_fit_contents(res):
    """Crawl ``res.url`` with crawl4ai and return its pruned ("fit") markdown.

    Args:
        res: An object with ``url`` and ``text`` attributes, e.g. a ScrapeResult from scrape_urls.

    Returns:
        At most 40,000 characters of the fit markdown. If the crawl times out (15 s) or raises,
        the first 40,000 characters of ``res.text`` (the rough extraction) are returned instead.
    """
    # 40,000 characters is usually about 10,000 tokens, so 5 results together stay under ~50k.
    str_fit_max = 40000
    browser_config = BrowserConfig(
        headless=True,
        verbose=True,
    )
    run_config = CrawlerRunConfig(
        cache_mode=CacheMode.DISABLED,
        markdown_generator=DefaultMarkdownGenerator(
            content_filter=PruningContentFilter(
                threshold=1.0, threshold_type="fixed", min_word_threshold=0
            )
        ),
        # markdown_generator=DefaultMarkdownGenerator(
        #     content_filter=BM25ContentFilter(user_query="WHEN_WE_FOCUS_BASED_ON_A_USER_QUERY", bm25_threshold=1.0)
        # ),
    )
    try:
        async with AsyncWebCrawler(config=browser_config) as crawler:
            # asyncio.wait_for bounds the crawl at 15 seconds.
            result = await asyncio.wait_for(
                crawler.arun(url=res.url, config=run_config), timeout=15
            )
            print(f"char before filtering {len(result.markdown.raw_markdown)}.")
            print(f"char after filtering {len(result.markdown.fit_markdown)}.")
            # On success, return the first str_fit_max characters of the fit markdown.
            return result.markdown.fit_markdown[:str_fit_max]
    except asyncio.TimeoutError:
        print(f"Timeout occurred while accessing {res.url}.")
        # On timeout, fall back to the rough text extracted earlier.
        return res.text[:str_fit_max]
    except Exception as e:
        print(f"Exception occurred: {str(e)}")
        # On any other error, fall back to the rough text extracted earlier.
        return res.text[:str_fit_max]