import json
import os
import ssl
import aiohttp
import asyncio
from agents import function_tool
# from ..workers.baseclass import ResearchAgent, ResearchRunner
# from ..workers.utils.parse_output import create_type_parser
from typing import List, Union, Optional
from bs4 import BeautifulSoup
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from crawl4ai import (
    AsyncWebCrawler,
    BrowserConfig,
    CacheMode,
    CrawlerRunConfig,
    DefaultMarkdownGenerator,
    PruningContentFilter,
)
load_dotenv()
CONTENT_LENGTH_LIMIT = 10000 # Trim scraped content to this length to avoid large context / token limit issues
SEARCH_PROVIDER = os.getenv("SEARCH_PROVIDER", "serper").lower()
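# Illustrative .env contents consumed via load_dotenv() above (placeholder values, not real keys):
#   SEARCH_PROVIDER=serper              # "serper" (default) or "openai"
#   SERPER_API_KEY=your-serper-api-key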
# ------- DEFINE TYPES -------
class ScrapeResult(BaseModel):
url: str = Field(description="The URL of the webpage")
text: str = Field(description="The full text content of the webpage")
title: str = Field(description="The title of the webpage")
description: str = Field(description="A short description of the webpage")
class WebpageSnippet(BaseModel):
url: str = Field(description="The URL of the webpage")
title: str = Field(description="The title of the webpage")
description: Optional[str] = Field(description="A short description of the webpage")
class SearchResults(BaseModel):
results_list: List[WebpageSnippet]
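# Minimal usage sketch for these models (hypothetical values; pydantic v2 API assumed):
# snippet = WebpageSnippet(url="https://example.com", title="Example Domain", description=None)
# SearchResults(results_list=[snippet]).model_dump_json()  # JSON string matching the schema embedded in FILTER_AGENT_INSTRUCTIONS below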
# ------- DEFINE TOOL -------
# Add a module-level variable to store the singleton instance
_serper_client = None
@function_tool
async def web_search(query: str) -> Union[List[ScrapeResult], str]:
"""Perform a web search for a given query and get back the URLs along with their titles, descriptions and text contents.
Args:
query: The search query
Returns:
List of ScrapeResult objects which have the following fields:
- url: The URL of the search result
- title: The title of the search result
- description: The description of the search result
- text: The full text content of the search result
"""
# Only use SerperClient if search provider is serper
if SEARCH_PROVIDER == "openai":
# For OpenAI search provider, this function should not be called directly
# The WebSearchTool from the agents module will be used instead
return f"The web_search function is not used when SEARCH_PROVIDER is set to 'openai'. Please check your configuration."
else:
try:
# Lazy initialization of SerperClient
global _serper_client
if _serper_client is None:
_serper_client = SerperClient()
search_results = await _serper_client.search(
query, filter_for_relevance=True, max_results=5
)
results = await scrape_urls(search_results)
return results
except Exception as e:
# Return a user-friendly error message
return f"Sorry, I encountered an error while searching: {str(e)}"
# ------- DEFINE AGENT FOR FILTERING SEARCH RESULTS BY RELEVANCE -------
FILTER_AGENT_INSTRUCTIONS = f"""
You are a search result filter. Your task is to analyze a list of SERP search results and determine which ones are relevant
to the original query based on the link, title and snippet. Return only the relevant results in the specified format.
- Remove any results that refer to entities that have similar names to the queried entity, but are not the same.
- E.g. if the query asks about a company "Acme Inc, acme.com", remove results with "acmesolutions.com" or "acme.net" in the link.
Only output JSON. Follow the JSON schema below. Do not output anything else. I will be parsing this with Pydantic so output valid JSON only:
{SearchResults.model_json_schema()}
"""
# selected_model = fast_model
#
# filter_agent = ResearchAgent(
# name="SearchFilterAgent",
# instructions=FILTER_AGENT_INSTRUCTIONS,
# model=selected_model,
# output_type=SearchResults if model_supports_structured_output(selected_model) else None,
# output_parser=create_type_parser(SearchResults) if not model_supports_structured_output(selected_model) else None
# )
# ------- DEFINE UNDERLYING TOOL LOGIC -------
# Create a shared SSL context for the aiohttp connectors (certificate verification disabled for broad compatibility when scraping)
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
ssl_context.set_ciphers(
"DEFAULT:@SECLEVEL=1"
) # Allow older cipher suites
class SerperClient:
"""A client for the Serper API to perform Google searches."""
def __init__(self, api_key: Optional[str] = None):
self.api_key = api_key or os.getenv("SERPER_API_KEY")
if not self.api_key:
raise ValueError(
"No API key provided. Set SERPER_API_KEY environment variable."
)
self.url = "https://google.serper.dev/search"
self.headers = {"X-API-KEY": self.api_key, "Content-Type": "application/json"}
async def search(
self, query: str, filter_for_relevance: bool = True, max_results: int = 5
) -> List[WebpageSnippet]:
"""Perform a Google search using Serper API and fetch basic details for top results.
Args:
query: The search query
num_results: Maximum number of results to return (max 10)
Returns:
Dictionary with search results
"""
connector = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(connector=connector) as session:
async with session.post(
self.url, headers=self.headers, json={"q": query, "autocorrect": False}
) as response:
response.raise_for_status()
results = await response.json()
results_list = [
WebpageSnippet(
url=result.get("link", ""),
title=result.get("title", ""),
description=result.get("snippet", ""),
)
for result in results.get("organic", [])
]
if not results_list:
return []
if not filter_for_relevance:
return results_list[:max_results]
# return results_list[:max_results]
return await self._filter_results(results_list, query, max_results=max_results)
async def _filter_results(
self, results: List[WebpageSnippet], query: str, max_results: int = 5
) -> List[WebpageSnippet]:
# Exclude PubMed / PMC source pages
filtered_results = [
res
for res in results
if "pmc.ncbi.nlm.nih.gov" not in res.url
and "pubmed.ncbi.nlm.nih.gov" not in res.url
]
# # get rid of unrelated data
# serialized_results = [result.model_dump() if isinstance(result, WebpageSnippet) else result for result in
# filtered_results]
#
# user_prompt = f"""
# Original search query: {query}
#
# Search results to analyze:
# {json.dumps(serialized_results, indent=2)}
#
# Return {max_results} search results or less.
# """
#
# try:
# result = await ResearchRunner.run(filter_agent, user_prompt)
# output = result.final_output_as(SearchResults)
# return output.results_list
# except Exception as e:
# print("Error filtering urls:", str(e))
# return filtered_results[:max_results]
async def fetch_url(session, url):
try:
async with session.get(url, timeout=5) as response:
return response.status == 200
except Exception as e:
print(f"Error accessing {url}: {str(e)}")
return False  # False indicates the URL is unreachable
async def filter_unreachable_urls(results):
async with aiohttp.ClientSession() as session:
tasks = [fetch_url(session, res.url) for res in results]
reachable = await asyncio.gather(*tasks)
return [
res for res, can_access in zip(results, reachable) if can_access
]
reachable_results = await filter_unreachable_urls(filtered_results)
# Return the first `max_results`, or fewer if there are not enough reachable results
return reachable_results[:max_results]
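# Hedged usage sketch for SerperClient (assumes SERPER_API_KEY is set; query is illustrative):
# client = SerperClient()
# snippets = await client.search("crawl4ai markdown extraction", filter_for_relevance=True, max_results=5)
# for s in snippets:
#     print(s.url, "-", s.title)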
async def scrape_urls(items: List[WebpageSnippet]) -> List[ScrapeResult]:
"""Fetch text content from provided URLs.
Args:
items: List of WebpageSnippet items to extract content from
Returns:
List of ScrapeResult objects which have the following fields:
- url: The URL of the search result
- title: The title of the search result
- description: The description of the search result
- text: The full text content of the search result
"""
connector = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(connector=connector) as session:
# Create list of tasks for concurrent execution
tasks = []
for item in items:
if item.url: # Skip empty URLs
tasks.append(fetch_and_process_url(session, item))
# Execute all tasks concurrently and gather results
results = await asyncio.gather(*tasks, return_exceptions=True)
# Filter out errors and return successful results
return [r for r in results if isinstance(r, ScrapeResult)]
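# Hedged sketch chaining search and scraping outside the agent tool (illustrative query):
# snippets = await SerperClient().search("site reliability engineering basics")
# pages = await scrape_urls(snippets)
# texts = [p.text for p in pages]  # each trimmed to CONTENT_LENGTH_LIMIT characters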
async def fetch_and_process_url(
session: aiohttp.ClientSession, item: WebpageSnippet
) -> ScrapeResult:
"""Helper function to fetch and process a single URL."""
if not is_valid_url(item.url):
return ScrapeResult(
url=item.url,
title=item.title,
description=item.description,
text=f"Error fetching content: URL contains restricted file extension",
)
try:
async with session.get(item.url, timeout=8) as response:
if response.status == 200:
content = await response.text()
# Run html_to_text in a thread pool to avoid blocking
text_content = await asyncio.get_event_loop().run_in_executor(
None, html_to_text, content
)
text_content = text_content[
:CONTENT_LENGTH_LIMIT
] # Trim content to avoid exceeding token limit
return ScrapeResult(
url=item.url,
title=item.title,
description=item.description,
text=text_content,
)
else:
# Instead of raising, return a ScrapeResult with an error message
return ScrapeResult(
url=item.url,
title=item.title,
description=item.description,
text=f"Error fetching content: HTTP {response.status}",
)
except Exception as e:
# Instead of raising, return a ScrapeResult with an error message
return ScrapeResult(
url=item.url,
title=item.title,
description=item.description,
text=f"Error fetching content: {str(e)}",
)
def html_to_text(html_content: str) -> str:
"""
Strip unnecessary elements from the HTML content to prepare it for text extraction / LLM processing.
"""
# Parse the HTML using lxml for speed
soup = BeautifulSoup(html_content, "lxml")
# Extract text from relevant tags
tags_to_extract = ("h1", "h2", "h3", "h4", "h5", "h6", "p", "li", "blockquote")
# Use a generator expression for efficiency
extracted_text = "\n".join(
element.get_text(strip=True)
for element in soup.find_all(tags_to_extract)
if element.get_text(strip=True)
)
return extracted_text
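# Illustrative behavior of html_to_text (only the listed tags are extracted):
# html_to_text("<html><body><h1>Title</h1><p>First paragraph.</p><div>Ignored</div></body></html>")
# # -> "Title\nFirst paragraph."  (text inside <div> is dropped because <div> is not in tags_to_extract)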
def is_valid_url(url: str) -> bool:
"""Check that a URL does not contain restricted file extensions."""
if any(
ext in url
for ext in [
".pdf",
".doc",
".xls",
".ppt",
".zip",
".rar",
".7z",
".txt",
".js",
".xml",
".css",
".png",
".jpg",
".jpeg",
".gif",
".ico",
".svg",
".webp",
".mp3",
".mp4",
".avi",
".mov",
".wmv",
".flv",
".wma",
".wav",
".m4a",
".m4v",
".m4b",
".m4p",
".m4u",
]
):
return False
return True
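# Illustrative behavior: the check is a plain substring match, so it also rejects URLs that merely
# contain one of these extensions anywhere in the string (e.g. ".js" matches ".json" endpoints):
# is_valid_url("https://example.com/report.pdf")       # -> False
# is_valid_url("https://api.example.com/data.json")    # -> False (".js" is a substring)
# is_valid_url("https://example.com/articles/intro")   # -> True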
async def url_to_contents(url):
async with AsyncWebCrawler() as crawler:
result = await crawler.arun(
url=url,
)
# print(result.markdown)
return result.markdown
async def url_to_fit_contents(res):
str_fit_max = 40000  # ~40,000 characters is roughly 10,000 tokens; five pages combined stay under ~50k tokens
browser_config = BrowserConfig(
headless=True,
verbose=True,
)
run_config = CrawlerRunConfig(
cache_mode=CacheMode.DISABLED,
markdown_generator=DefaultMarkdownGenerator(
content_filter=PruningContentFilter(
threshold=1.0, threshold_type="fixed", min_word_threshold=0
)
),
# markdown_generator=DefaultMarkdownGenerator(
# content_filter=BM25ContentFilter(user_query="WHEN_WE_FOCUS_BASED_ON_A_USER_QUERY", bm25_threshold=1.0)
# ),
)
try:
async with AsyncWebCrawler(config=browser_config) as crawler:
# Use asyncio.wait_for to enforce a timeout
result = await asyncio.wait_for(
crawler.arun(url=res.url, config=run_config), timeout=15  # timeout in seconds
)
print(f"char before filtering {len(result.markdown.raw_markdown)}.")
print(f"char after filtering {len(result.markdown.fit_markdown)}.")
return result.markdown.fit_markdown[
:str_fit_max
]  # On success, return the first str_fit_max characters of the filtered markdown
except asyncio.TimeoutError:
print(f"Timeout occurred while accessing {res.url}.") # 打印超时信息
return res.text[:str_fit_max] # 如果发生超时,返回res粗略提取
except Exception as e:
print(f"Exception occurred: {str(e)}") # 打印其他异常信息
return res.text[:str_fit_max] # 如果发生其他异常,返回res粗略提取
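# Minimal manual smoke test (hedged: assumes network access and a working crawl4ai browser install; URL is illustrative).
if __name__ == "__main__":
    async def _demo():
        markdown = await url_to_contents("https://example.com")
        print(str(markdown)[:500])
    asyncio.run(_demo())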