Spaces:
Starting
Starting
import asyncio
import functools
import logging

from datasets import load_dataset
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from rank_bm25 import BM25Okapi
# Module-level logger, named after this module so log records can be
# filtered per module by the application's logging configuration.
logger = logging.getLogger(__name__)
class GuestInfoInput(BaseModel):
    # Argument schema for the guest-info tool: one required free-text query.
    # Documented with comments rather than a docstring so the generated
    # schema's description is left exactly as it was.
    query: str = Field(description="Query about guest information")
@functools.lru_cache(maxsize=1)
def _load_guest_index():
    """Load the invitee dataset and build its BM25 index (cached after first success)."""
    dataset = load_dataset("agents-course/unit3-invitees", split="train")
    logger.info("Loaded %d guests from Hugging Face dataset", len(dataset))
    # Each guest is indexed by "<name> <relation>", lower-cased and split on
    # whitespace; queries are tokenized the same way before scoring.
    tokenized_docs = [
        f"{row['name']} {row['relation']}".lower().split() for row in dataset
    ]
    # lru_cache does not store raised exceptions, so a failed load is retried
    # on the next call instead of being cached.
    return dataset, BM25Okapi(tokenized_docs)


async def guest_info_func(query: str) -> str:
    """
    Retrieve guest information based on a query.
    Args:
        query (str): Query about guest information.
    Returns:
        str: Guest information or error message.
    """
    # NOTE: the docstring wording above is kept unchanged on purpose. This
    # function is handed to StructuredTool.from_function without an explicit
    # description, so the docstring is what the tool exposes as its description.
    try:
        logger.info("Retrieving guest info for query: %s", query)
        dataset, bm25 = _load_guest_index()
        scores = bm25.get_scores(query.lower().split())
        # Cast the argmax result to a plain int before indexing the dataset,
        # so indexing never depends on how it treats non-builtin integer types.
        best_idx = int(scores.argmax())
        # A score of 0 means no query token matched any guest document.
        if scores[best_idx] > 0:
            guest = dataset[best_idx]
            return f"Guest: {guest['name']}, Relation: {guest['relation']}"
        return "No matching guest found"
    except Exception as e:
        # Tool boundary: report the failure as text instead of raising, but
        # keep the traceback in the logs.
        logger.exception("Error retrieving guest info for query '%s': %s", query, e)
        return f"Error: {e}"
def _guest_info_sync(query: str) -> str:
    # Synchronous adapter for guest_info_func. Passing the coroutine function
    # itself as `func` would make a sync invocation return an un-awaited
    # coroutine object instead of the result string.
    # asyncio.run raises RuntimeError if called from a thread that already has
    # a running event loop; async callers should use the coroutine path.
    return asyncio.run(guest_info_func(query))


# Copy the coroutine's docstring so the description that
# StructuredTool.from_function infers from `func` is the same text as before.
_guest_info_sync.__doc__ = guest_info_func.__doc__

# Tool exposing guest lookup: `func` serves sync invocation, `coroutine`
# serves async invocation, and GuestInfoInput validates the arguments.
guest_info_retriever_tool = StructuredTool.from_function(
    func=_guest_info_sync,
    name="guest_info_retriever_tool",
    args_schema=GuestInfoInput,
    coroutine=guest_info_func,
)