File size: 3,682 Bytes
1bbca12
488dc3e
 
 
 
 
1bbca12
 
488dc3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1bbca12
 
488dc3e
 
1bbca12
488dc3e
 
 
 
 
 
 
 
1bbca12
488dc3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1bbca12
488dc3e
 
 
 
 
 
 
 
 
 
 
 
 
 
1bbca12
488dc3e
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from langchain_core.tools import tool
from langchain_huggingface import HuggingFacePipeline
from sentence_transformers import SentenceTransformer
import logging
from typing import List, Dict, Any
import requests
import os

logger = logging.getLogger(__name__)

# Initialize embedding model (free, open-source)
# Module import-time side effect: downloads/loads the MiniLM sentence-encoder
# on first use. On failure we degrade gracefully — tools below check for None.
try:
    embedder = SentenceTransformer("all-MiniLM-L6-v2")
except Exception as e:
    logger.error(f"Failed to initialize embedding model: {e}")
    embedder = None  # sentinel: search tools fall back to a "unavailable" result

# Global LLM instance
# Set once via initialize_search_tools(); tools return an "unavailable"
# result until it is populated.
search_llm = None

def initialize_search_tools(llm: HuggingFacePipeline) -> None:
    """Register the shared LLM used by the search tools.

    Stores *llm* in the module-level ``search_llm`` slot, which both
    ``search_tool`` and ``multi_hop_search_tool`` consult for query
    refinement. Must be called before either tool is invoked.

    Args:
        llm: HuggingFace pipeline wrapper used to refine search queries.
    """
    global search_llm
    search_llm = llm
    logger.info("Search tools initialized with HuggingFace LLM")

@tool
async def search_tool(query: str) -> List[Dict[str, Any]]:
    """Perform a web search using the query.

    Refines the query via the shared LLM, then queries SerpAPI when an
    API key is configured; otherwise (or on SerpAPI failure) returns a
    mock result set so callers always receive the same shape.

    Args:
        query: Free-text search query.

    Returns:
        A list of dicts with ``content`` (snippet) and ``url`` keys.
        Never raises — failures are reported as a single-item result list.
    """
    try:
        if not search_llm:
            logger.warning("Search LLM not initialized")
            return [{"content": "Search unavailable", "url": ""}]

        # Refine query using LLM.
        prompt = f"Refine this search query for better results: {query}"
        response = await search_llm.ainvoke(prompt)
        # HuggingFacePipeline.ainvoke returns a plain string; chat-style
        # models return a message object with .content. Accept both —
        # the original `.content` access raised AttributeError on str.
        refined_query = (
            response.content if hasattr(response, "content") else str(response)
        ).strip()

        # Check for SerpAPI key (free tier available)
        serpapi_key = os.getenv("SERPAPI_API_KEY")
        if serpapi_key:
            try:
                params = {"q": refined_query, "api_key": serpapi_key}
                # timeout prevents an unresponsive endpoint from hanging
                # this async tool indefinitely (requests is blocking).
                http_resp = requests.get(
                    "https://serpapi.com/search", params=params, timeout=10
                )
                http_resp.raise_for_status()
                results = http_resp.json().get("organic_results", [])
                return [{"content": r.get("snippet", ""), "url": r.get("link", "")} for r in results]
            except Exception as e:
                logger.warning(f"SerpAPI failed: {e}, falling back to mock search")

        # Mock search if no API key or API fails. The embedder gate is kept
        # so behavior matches availability of the embedding model; the
        # previously computed (and never used) query embedding was dropped.
        if embedder:
            results = [
                {"content": f"Mock result for {refined_query}", "url": "https://example.com"},
                {"content": f"Another mock result for {refined_query}", "url": "https://example.org"}
            ]
        else:
            results = [{"content": "Embedding model unavailable", "url": ""}]

        logger.info(f"Search results for query '{refined_query}': {len(results)} items")
        return results
    except Exception as e:
        logger.error(f"Error in search_tool: {e}")
        return [{"content": f"Search failed: {str(e)}", "url": ""}]

@tool
async def multi_hop_search_tool(query: str, steps: int = 3) -> List[Dict[str, Any]]:
    """Perform a multi-hop search by iteratively refining the query.

    At each hop the LLM generates a follow-up question from the current
    query, which is fed to ``search_tool``; all hop results are
    concatenated.

    Args:
        query: Initial free-text query.
        steps: Number of hops to perform (``steps <= 0`` yields ``[]``).

    Returns:
        Combined list of result dicts from every hop. Never raises —
        failures are reported as a single-item result list.
    """
    try:
        if not search_llm:
            logger.warning("Search LLM not initialized")
            return [{"content": "Multi-hop search unavailable", "url": ""}]

        results = []
        current_query = query
        for step in range(steps):
            prompt = f"Based on the query '{current_query}', generate a follow-up question to deepen the search."
            response = await search_llm.ainvoke(prompt)
            # Accept both raw-string (HuggingFacePipeline) and message-object
            # returns — plain str has no .content attribute.
            next_query = (
                response.content if hasattr(response, "content") else str(response)
            ).strip()

            # search_tool is an async tool: the async entry point is
            # .ainvoke(); the original awaited the sync .invoke(), which
            # does not work for coroutine-based tools.
            step_results = await search_tool.ainvoke({"query": next_query})
            results.extend(step_results)
            current_query = next_query
            logger.info(f"Multi-hop step {step + 1}: {next_query}")

        return results
    except Exception as e:
        logger.error(f"Error in multi_hop_search_tool: {e}")
        return [{"content": f"Multi-hop search failed: {str(e)}", "url": ""}]