File size: 1,628 Bytes
4701375
751d628
 
 
 
4701375
 
 
751d628
 
 
 
 
 
 
 
 
 
 
 
 
4701375
751d628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4701375
 
751d628
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import asyncio
import functools
import logging

from datasets import load_dataset
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from rank_bm25 import BM25Okapi

# Module-level logger keyed on this module's import name.
logger: logging.Logger = logging.getLogger(name=__name__)

class GuestInfoInput(BaseModel):
    # Argument schema passed as ``args_schema`` to guest_info_retriever_tool:
    # a single, required free-text query string.
    # NOTE: documented with comments rather than a class docstring, because
    # pydantic copies a class docstring into the generated JSON schema.
    query: str = Field(
        ...,  # Ellipsis marks the field as required (no default)
        description="Query about guest information",
    )

@functools.lru_cache(maxsize=1)
def _load_guest_index():
    """Load the invitee dataset and build a BM25 index over it, once per process.

    ``lru_cache`` memoizes a successful result, so later queries skip the
    dataset load and the index build. A call that raises is not cached, so a
    transient load failure is retried on the next query. Concurrent first
    calls may each run this body; the cache still keeps a single result.

    Returns:
        tuple: ``(guests, bm25)``. ``guests`` is a tuple of ``(name, relation)``
        pairs in dataset order. ``bm25`` is a ``BM25Okapi`` whose document ``i``
        is built from ``guests[i]``, or ``None`` when the dataset has no rows.
    """
    dataset = load_dataset("agents-course/unit3-invitees", split="train")
    logger.info("Loaded %d guests from Hugging Face dataset", len(dataset))
    # Copy the two fields once so a lookup later never re-indexes the dataset.
    guests = tuple((row["name"], row["relation"]) for row in dataset)
    if not guests:
        # rank_bm25 averages document length over the corpus size, so an
        # empty corpus cannot be indexed (it would divide by zero).
        return guests, None
    tokenized_docs = [
        f"{name} {relation}".lower().split() for name, relation in guests
    ]
    return guests, BM25Okapi(tokenized_docs)


def _search_guests(query: str) -> str:
    """Return the best BM25 match for *query* or a no-match message (blocking)."""
    guests, bm25 = _load_guest_index()
    if bm25 is None:
        return "No matching guest found"
    scores = bm25.get_scores(query.lower().split())
    # argmax picks the first row among equal top scores; int() turns the
    # NumPy scalar into a plain index.
    best_idx = int(scores.argmax())
    if scores[best_idx] > 0:
        name, relation = guests[best_idx]
        return f"Guest: {name}, Relation: {relation}"
    return "No matching guest found"


async def guest_info_func(query: str) -> str:
    """
    Retrieve guest information based on a query.

    Args:
        query (str): Query about guest information.

    Returns:
        str: Guest information or error message.
    """
    # NOTE(review): StructuredTool.from_function is called without
    # ``description=``, so this docstring presumably doubles as the tool
    # description shown to the model. It is kept textually unchanged.
    logger.info("Retrieving guest info for query: %s", query)
    try:
        # Loading the dataset and scoring with BM25 are blocking calls. Run
        # them in the default executor so the event loop is not stalled.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, _search_guests, query)
    except Exception as e:
        # Tool boundary: report failures to the caller as text rather than
        # raising. logger.exception also records the traceback.
        logger.exception("Error retrieving guest info for query '%s': %s", query, e)
        return f"Error: {str(e)}"

# StructuredTool calls ``func`` directly for synchronous ``invoke`` and awaits
# ``coroutine`` for ``ainvoke``. Passing the async ``guest_info_func`` as
# ``func`` made sync calls return an un-awaited coroutine object instead of a
# string, so sync calls are routed through this bridge instead.
# functools.wraps copies guest_info_func's __doc__ onto the bridge, so the
# description that from_function derives from ``func`` stays the same.
@functools.wraps(guest_info_func)
def _guest_info_sync(query: str) -> str:
    # NOTE: asyncio.run() raises RuntimeError when called from a thread that
    # already has a running event loop; async callers should use ainvoke().
    return asyncio.run(guest_info_func(query))


guest_info_retriever_tool = StructuredTool.from_function(
    func=_guest_info_sync,
    name="guest_info_retriever_tool",
    args_schema=GuestInfoInput,
    coroutine=guest_info_func,
)