Spaces:
Sleeping
Sleeping
File size: 4,372 Bytes
640b1c8 0739c8b 640b1c8 e87abff 0739c8b 640b1c8 0739c8b 640b1c8 0739c8b 640b1c8 0739c8b 640b1c8 0739c8b 640b1c8 0739c8b 640b1c8 0739c8b 640b1c8 0739c8b 640b1c8 0739c8b 640b1c8 0739c8b 640b1c8 0739c8b 640b1c8 0739c8b 640b1c8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
# src/agents/rag_agent.py
from typing import List, Optional, Tuple, Dict
from ..llms.base_llm import BaseLLM
from src.embeddings.base_embedding import BaseEmbedding
from src.vectorstores.base_vectorstore import BaseVectorStore
from src.utils.text_splitter import split_text
from src.models.rag import RAGResponse
class RAGAgent:
    """Retrieval-Augmented Generation agent.

    Wires together three collaborators: an embedding model to vectorize
    queries, a vector store to find similar documents, and an LLM to
    answer the query with the retrieved documents as context.
    """

    def __init__(
        self,
        llm: BaseLLM,
        embedding: BaseEmbedding,
        vector_store: BaseVectorStore
    ):
        self.llm = llm
        self.embedding = embedding
        self.vector_store = vector_store

    def _convert_metadata_to_strings(self, metadata: Dict) -> Dict:
        """Return a copy of *metadata* with int/float values stringified.

        Non-numeric values pass through unchanged. (Note that in Python
        bools are ints, so True/False would also become "True"/"False".)
        """
        return {
            key: str(val) if isinstance(val, (int, float)) else val
            for key, val in metadata.items()
        }

    def retrieve_context(
        self,
        query: str,
        top_k: int = 3
    ) -> Tuple[List[str], List[Dict], Optional[List[float]]]:
        """
        Retrieve relevant context documents for a given query.

        Args:
            query (str): Input query to find context for.
            top_k (int): Number of top context documents to retrieve.

        Returns:
            Tuple[List[str], List[Dict], Optional[List[float]]]:
                Retrieved document texts, their (stringified) metadata,
                and similarity scores — or None for the scores when the
                store did not report a score for every hit.
        """
        query_vector = self.embedding.embed_query(query)
        hits = self.vector_store.similarity_search(
            query_vector,
            top_k=top_k
        )

        texts = [hit['text'] for hit in hits]
        metadatas = [
            self._convert_metadata_to_strings(hit['metadata']) for hit in hits
        ]
        hit_scores = [
            hit['score'] for hit in hits if hit.get('score') is not None
        ]
        # A partial score list would no longer align index-for-index with
        # the documents, so report scores only when every hit carried one.
        if len(hit_scores) != len(texts):
            hit_scores = None
        return texts, metadatas, hit_scores

    async def generate_response(
        self,
        query: str,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        context_docs: Optional[List[str]] = None
    ) -> RAGResponse:
        """
        Generate a response using the RAG approach.

        Args:
            query (str): User input query.
            temperature (float): Sampling temperature for the LLM.
            max_tokens (Optional[int]): Maximum tokens to generate.
            context_docs (Optional[List[str]]): Optional pre-provided
                context documents; when given (and non-empty), retrieval
                is skipped and sources/scores are None.

        Returns:
            RAGResponse: Response with generated text and context.
        """
        if context_docs:
            # Caller supplied context directly; nothing was retrieved,
            # so there are no sources or scores to report.
            sources = None
            scores = None
        else:
            context_docs, sources, scores = self.retrieve_context(query)

        prompt = self._construct_prompt(query, context_docs)
        # NOTE(review): generate() is called without await inside this
        # async method — confirm whether BaseLLM.generate is synchronous
        # (which would block the event loop) or an awaitable being dropped.
        response = self.llm.generate(
            prompt,
            temperature=temperature,
            max_tokens=max_tokens
        )
        return RAGResponse(
            response=response,
            context_docs=context_docs,
            sources=sources,
            scores=scores
        )

    def _construct_prompt(
        self,
        query: str,
        context_docs: List[str]
    ) -> str:
        """
        Construct a prompt embedding the retrieved context.

        Args:
            query (str): Original user query.
            context_docs (List[str]): Retrieved context documents.

        Returns:
            str: Augmented prompt for the LLM.
        """
        joined_context = "\n\n".join(context_docs)
        return f"""
Context Information:
{joined_context}
User Query: {query}
Based on the context, please provide a comprehensive and accurate response.
"""