# src/agents/rag_agent.py
from typing import List, Optional, Tuple, Dict
import uuid

from ..llms.base_llm import BaseLLM
from src.embeddings.base_embedding import BaseEmbedding
from src.vectorstores.base_vectorstore import BaseVectorStore
from src.utils.conversation_manager import ConversationManager
from src.db.mongodb_store import MongoDBStore
from src.models.rag import RAGResponse
from src.utils.logger import logger


class RAGAgent:
    def __init__(
        self,
        llm: BaseLLM,
        embedding: BaseEmbedding,
        vector_store: BaseVectorStore,
        mongodb: MongoDBStore,
        max_history_tokens: int = 4000,
        max_history_messages: int = 10
    ):
        """
        Initialize RAG Agent

        Args:
            llm (BaseLLM): Language model instance
            embedding (BaseEmbedding): Embedding model instance
            vector_store (BaseVectorStore): Vector store instance
            mongodb (MongoDBStore): MongoDB store instance
            max_history_tokens (int): Maximum tokens in conversation history
            max_history_messages (int): Maximum messages to keep in history
        """
        self.llm = llm
        self.embedding = embedding
        self.vector_store = vector_store
        self.mongodb = mongodb
        self.conversation_manager = ConversationManager(
            max_tokens=max_history_tokens,
            max_messages=max_history_messages
        )

    async def generate_response(
        self,
        query: str,
        conversation_id: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        context_docs: Optional[List[str]] = None
    ) -> RAGResponse:
        """Generate a response using RAG with conversation history"""
        try:
            # Create new conversation if no ID provided
            if not conversation_id:
                conversation_id = str(uuid.uuid4())
                await self.mongodb.create_conversation(conversation_id)

            # Get conversation history
            history = await self.mongodb.get_recent_messages(
                conversation_id,
                limit=self.conversation_manager.max_messages
            )

            # Get relevant history within token limits
            relevant_history = self.conversation_manager.get_relevant_history(
                messages=history,
                current_query=query
            ) if history else []

            # Retrieve context if not provided
            if not context_docs:
                context_docs, sources, scores = await self.retrieve_context(
                    query,
                    conversation_history=relevant_history
                )
            else:
                sources = None
                scores = None

            # Generate prompt with context and history
            augmented_prompt = self.conversation_manager.generate_prompt_with_history(
                current_query=query,
                history=relevant_history,
                context_docs=context_docs
            )

            # Generate response using LLM
            response = self.llm.generate(
                augmented_prompt,
                temperature=temperature,
                max_tokens=max_tokens
            )

            return RAGResponse(
                response=response,
                context_docs=context_docs,
                sources=sources,
                scores=scores
            )

        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            raise

    async def retrieve_context(
        self,
        query: str,
        conversation_history: Optional[List[Dict]] = None,
        top_k: int = 3
    ) -> Tuple[List[str], List[Dict], Optional[List[float]]]:
        """
        Retrieve context with conversation history enhancement

        Args:
            query (str): Current query
            conversation_history (Optional[List[Dict]]): Recent conversation history
            top_k (int): Number of documents to retrieve

        Returns:
            Tuple[List[str], List[Dict], Optional[List[float]]]:
                Retrieved documents, sources, and scores
        """
        # Enhance query with the last two user queries from the conversation history
        if conversation_history:
            recent_queries = [
                msg['query']
                for msg in conversation_history[-2:]
                if msg.get('query')
            ]
            enhanced_query = " ".join([*recent_queries, query])
        else:
            enhanced_query = query

        # Embed the enhanced query
        query_embedding = self.embedding.embed_query(enhanced_query)

        # Retrieve similar documents
        results = self.vector_store.similarity_search(
            query_embedding,
            top_k=top_k
        )

        # Process results
        documents = [doc['text'] for doc in results]
        sources = [
            self._convert_metadata_to_strings(doc['metadata'])
            for doc in results
        ]
        scores = [
            doc['score']
            for doc in results
            if doc.get('score') is not None
        ]

        # Return scores only if available for all documents
        if len(scores) != len(documents):
            scores = None

        return documents, sources, scores

    def _convert_metadata_to_strings(self, metadata: Dict) -> Dict:
        """Convert numeric metadata values to strings"""
        converted = {}
        for key, value in metadata.items():
            if isinstance(value, (int, float)):
                converted[key] = str(value)
            else:
                converted[key] = value
        return converted
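

# --- Usage sketch (illustrative only, kept commented out) ---
# A minimal example of wiring up RAGAgent. The concrete classes, module paths,
# and constructor arguments below (OpenAILLM, OpenAIEmbedding, ChromaVectorStore,
# and the MongoDB connection string) are assumptions for illustration; only
# RAGAgent and MongoDBStore come from this codebase, and MongoDBStore's real
# constructor signature may differ.
#
# import asyncio
# from src.llms.openai_llm import OpenAILLM                    # hypothetical
# from src.embeddings.openai_embedding import OpenAIEmbedding  # hypothetical
# from src.vectorstores.chroma_store import ChromaVectorStore  # hypothetical
#
# async def main():
#     agent = RAGAgent(
#         llm=OpenAILLM(),
#         embedding=OpenAIEmbedding(),
#         vector_store=ChromaVectorStore(),
#         mongodb=MongoDBStore("mongodb://localhost:27017"),   # args assumed
#     )
#     # Omitting conversation_id creates a new conversation automatically
#     result = await agent.generate_response("What does the refund policy say?")
#     print(result.response)
#
# if __name__ == "__main__":
#     asyncio.run(main())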