File size: 4,372 Bytes
640b1c8
0739c8b
640b1c8
 
e87abff
 
 
0739c8b
640b1c8
 
 
 
 
 
 
 
 
 
 
 
0739c8b
 
 
 
 
 
 
 
 
 
640b1c8
 
 
 
0739c8b
640b1c8
 
 
 
 
 
 
 
0739c8b
640b1c8
 
 
 
0739c8b
 
640b1c8
 
 
 
0739c8b
 
 
 
 
 
 
 
 
 
 
 
 
640b1c8
0739c8b
640b1c8
0739c8b
 
 
640b1c8
 
 
 
 
 
 
0739c8b
 
640b1c8
 
 
 
 
 
 
0739c8b
 
 
 
640b1c8
 
 
 
0739c8b
 
 
 
 
 
640b1c8
 
0739c8b
 
 
 
640b1c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# src/agents/rag_agent.py
from typing import List, Optional, Tuple, Dict

from ..llms.base_llm import BaseLLM
from src.embeddings.base_embedding import BaseEmbedding
from src.vectorstores.base_vectorstore import BaseVectorStore
from src.utils.text_splitter import split_text
from src.models.rag import RAGResponse

class RAGAgent:
    """Retrieval-Augmented Generation agent.

    Wires together an embedding model, a vector store, and an LLM:
    queries are embedded, similar documents are retrieved from the
    store, and the LLM generates an answer with the retrieved context
    prepended to the prompt.
    """

    def __init__(
        self, 
        llm: BaseLLM, 
        embedding: BaseEmbedding, 
        vector_store: BaseVectorStore
    ):
        """
        Args:
            llm (BaseLLM): Text-generation backend.
            embedding (BaseEmbedding): Model used to embed queries.
            vector_store (BaseVectorStore): Store searched for similar documents.
        """
        self.llm = llm
        self.embedding = embedding
        self.vector_store = vector_store
    
    def _convert_metadata_to_strings(self, metadata: Dict) -> Dict:
        """Return a copy of *metadata* with int/float values cast to str.

        Non-numeric values pass through unchanged.

        NOTE(review): bool is a subclass of int in Python, so True/False
        are also stringified to "True"/"False" here — confirm downstream
        consumers expect that; exclude bool explicitly if not.
        """
        return {
            key: str(value) if isinstance(value, (int, float)) else value
            for key, value in metadata.items()
        }
    
    def retrieve_context(
        self, 
        query: str, 
        top_k: int = 3
    ) -> Tuple[List[str], List[Dict], Optional[List[float]]]:
        """
        Retrieve relevant context documents for a given query.
        
        Args:
            query (str): Input query to find context for
            top_k (int): Number of top context documents to retrieve
        
        Returns:
            Tuple[List[str], List[Dict], Optional[List[float]]]:
                Retrieved document texts, their (stringified) metadata,
                and similarity scores — scores is None unless every
                result carried one.
        """
        # Embed once, then ask the store for the nearest matches.
        query_embedding = self.embedding.embed_query(query)
        results = self.vector_store.similarity_search(
            query_embedding, 
            top_k=top_k
        )
        
        documents = [doc['text'] for doc in results]
        # Numeric metadata values are stringified for downstream consumers.
        sources = [self._convert_metadata_to_strings(doc['metadata']) for doc in results]
        scores = [doc['score'] for doc in results if doc.get('score') is not None]
        
        # A partial score list would misalign with `documents`, so only
        # report scores when every result has one.
        if len(scores) != len(documents):
            scores = None
            
        return documents, sources, scores
    
    async def generate_response(
        self, 
        query: str,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        context_docs: Optional[List[str]] = None,
        top_k: int = 3
    ) -> RAGResponse:
        """
        Generate a response using the RAG approach.
        
        Args:
            query (str): User input query
            temperature (float): Sampling temperature for the LLM
            max_tokens (Optional[int]): Maximum tokens to generate
            context_docs (Optional[List[str]]): Optional pre-provided context
                documents; when given (and non-empty), retrieval is skipped.
            top_k (int): Number of documents to retrieve when context_docs
                is not supplied.
        
        Returns:
            RAGResponse: Response with generated text and context
        """
        # If no context provided (None or empty), retrieve from the store.
        if not context_docs:
            context_docs, sources, scores = self.retrieve_context(query, top_k=top_k)
        else:
            # Caller supplied the context, so there is no provenance to report.
            sources = None
            scores = None
        
        augmented_prompt = self._construct_prompt(query, context_docs)
        
        # NOTE(review): this coroutine never awaits anything; if
        # BaseLLM.generate is itself async this call needs an `await` —
        # confirm against the BaseLLM interface.
        response = self.llm.generate(
            augmented_prompt,
            temperature=temperature,
            max_tokens=max_tokens
        )
        
        return RAGResponse(
            response=response,
            context_docs=context_docs,
            sources=sources,
            scores=scores
        )
    
    def _construct_prompt(
        self, 
        query: str, 
        context_docs: List[str]
    ) -> str:
        """
        Construct a prompt with retrieved context.
        
        Args:
            query (str): Original user query
            context_docs (List[str]): Retrieved context documents
        
        Returns:
            str: Augmented prompt for the LLM
        """
        context_str = "\n\n".join(context_docs)
        
        return f"""
        Context Information:
        {context_str}
        
        User Query: {query}
        
        Based on the context, please provide a comprehensive and accurate response.
        """