"""
Utilities for Retrieval Augmented Generation (RAG) setup and operations.
"""
import os
import chainlit as cl
from langchain_community.document_loaders import DirectoryLoader, UnstructuredFileLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings  # Local, open-source embedding model
from langchain_openai import OpenAIEmbeddings  # Used when model_type="openai" is requested
from langchain_qdrant import Qdrant
# from qdrant_client import QdrantClient # QdrantClient might be used for more direct/advanced Qdrant operations
# Recommended embedding model for good performance and local use.
# Ensure this model is downloaded or will be downloaded by sentence-transformers.
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
# Qdrant recommends using FastEmbed for optimal performance with their DB.
# from langchain_community.embeddings import FastEmbedEmbeddings
def load_and_split_documents(persona_id: str, data_path: str = "data_sources"):
"""
Loads all .txt files from the persona's data directory, splits them into
chunks, and returns a list of Langchain Document objects.
"""
persona_data_path = os.path.join(data_path, persona_id)
if not os.path.isdir(persona_data_path):
print(f"Warning: Data directory not found for persona {persona_id} at {persona_data_path}")
return []
# Using DirectoryLoader with UnstructuredFileLoader for .txt files
# It's good at handling basic text. For more complex .txt or other formats,
# you might need more specific loaders or pre-processing.
loader = DirectoryLoader(
persona_data_path,
glob="**/*.txt", # Load all .txt files, including in subdirectories
loader_cls=UnstructuredFileLoader,
show_progress=True,
use_multithreading=True,
        silent_errors=True  # Skip files that fail to load instead of raising; log them separately if needed
)
try:
loaded_documents = loader.load()
if not loaded_documents:
print(f"No documents found or loaded for persona {persona_id} from {persona_data_path}")
return []
except Exception as e:
print(f"Error loading documents for persona {persona_id}: {e}")
return []
text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,  # Target chunk size in characters; tune for your content
        chunk_overlap=150,  # Overlap between adjacent chunks to preserve context across boundaries
length_function=len,
is_separator_regex=False,
)
split_docs = text_splitter.split_documents(loaded_documents)
print(f"Loaded and split {len(loaded_documents)} documents into {len(split_docs)} chunks for persona {persona_id}.")
return split_docs
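
# Example usage (a minimal sketch; "sherlock" is a hypothetical persona id whose
# .txt files would live under data_sources/sherlock/):
#
#   chunks = load_and_split_documents("sherlock")
#   for doc in chunks[:3]:
#       print(doc.page_content[:100])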
def get_embedding_model_instance(model_identifier: str, model_type: str = "hf"):
"""
Initializes and returns an embedding model instance.
model_type can be 'hf' for HuggingFace or 'openai'.
"""
if model_type == "openai":
print(f"Initializing OpenAI Embedding Model: {model_identifier}")
return OpenAIEmbeddings(model=model_identifier)
elif model_type == "hf":
print(f"Initializing HuggingFace Embedding Model: {model_identifier}")
        # EMBEDDING_MODEL_NAME (defined above) remains available as a default,
        # but callers such as app.py pass an explicit model_identifier.
        return HuggingFaceEmbeddings(model_name=model_identifier)
else:
raise ValueError(f"Unsupported embedding model_type: {model_type}")
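
# Example usage (a sketch; both model names are common choices, not requirements
# of this module):
#
#   hf_embeddings = get_embedding_model_instance(EMBEDDING_MODEL_NAME, model_type="hf")
#   openai_embeddings = get_embedding_model_instance("text-embedding-3-small", model_type="openai")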
async def get_or_create_persona_vector_store(persona_id: str, embedding_model):
"""Gets a vector store for a persona from the session, or creates and caches it."""
vector_store_key = f"vector_store_{persona_id}"
    cached_store = cl.user_session.get(vector_store_key)
    if cached_store is not None:
        print(f"Found existing vector store for {persona_id} in session.")
        return cached_store
print(f"No existing vector store for {persona_id} in session. Creating new one...")
documents = load_and_split_documents(persona_id)
if not documents:
print(f"No documents to create vector store for persona {persona_id}.")
cl.user_session.set(vector_store_key, None) # Mark as attempted but failed
return None
try:
# Qdrant.from_documents will handle creating the client and collection in-memory
vector_store = await cl.make_async(Qdrant.from_documents)(
documents,
embedding_model,
location=":memory:", # Specifies in-memory Qdrant
collection_name=f"{persona_id}_store_{cl.user_session.get('id', 'default_session')}", # Unique per session
# distance_func="Cosine", # Qdrant default is Cosine, or use models.Distance.COSINE
            prefer_grpc=False  # gRPC is unnecessary for a local/in-memory instance
)
print(f"Successfully created vector store for {persona_id} with {len(documents)} chunks.")
cl.user_session.set(vector_store_key, vector_store)
return vector_store
except Exception as e:
print(f"Error creating Qdrant vector store for {persona_id}: {e}")
cl.user_session.set(vector_store_key, None) # Mark as attempted but failed
return None
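
# Example usage (a sketch; this must run inside a Chainlit handler so that
# cl.user_session is available; "sherlock" is a hypothetical persona id):
#
#   embeddings = get_embedding_model_instance(EMBEDDING_MODEL_NAME)
#   store = await get_or_create_persona_vector_store("sherlock", embeddings)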
async def get_relevant_context_for_query(query: str, persona_id: str, embedding_model) -> str:
"""
Retrieves relevant context from the persona's vector store for a given query.
Returns a string of concatenated context or an empty string if no context found.
"""
vector_store = await get_or_create_persona_vector_store(persona_id, embedding_model)
if not vector_store:
print(f"No vector store available for {persona_id} to retrieve context.")
return ""
retriever = vector_store.as_retriever(search_kwargs={"k": 3}) # Retrieve top 3 chunks
try:
results = await cl.make_async(retriever.invoke)(query)
if results:
context = "\n\n---\n\n".join([doc.page_content for doc in results])
# print(f"Retrieved context for {persona_id} query '{query}':\n{context[:500]}...") # Debug
return context
else:
print(f"No relevant documents found for {persona_id} query: '{query}'")
return ""
except Exception as e:
print(f"Error retrieving context for {persona_id} query '{query}': {e}")
return "" |