File size: 2,074 Bytes
640b1c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# src/vectorstores/chroma_vectorstore.py
import uuid
from typing import Any, Callable, List, Optional

import chromadb

from .base_vectorstore import BaseVectorStore

class ChromaVectorStore(BaseVectorStore):
    """Vector store backed by a persistent Chroma collection.

    Documents and their embeddings are stored in one collection of a
    ``chromadb.PersistentClient`` rooted at ``persist_directory``.
    """

    def __init__(
        self,
        embedding_function: Callable[[List[str]], List[List[float]]],
        persist_directory: str = './chroma_db',
        collection_name: str = "documents",
    ):
        """
        Initialize Chroma Vector Store

        Args:
            embedding_function (Callable): Function to generate embeddings
                for a list of texts; used by ``add_documents`` when no
                pre-computed embeddings are supplied
            persist_directory (str): Directory to persist the vector store
            collection_name (str): Name of the collection to open, created
                if it does not exist yet
        """
        self.client = chromadb.PersistentClient(path=persist_directory)
        self.collection = self.client.get_or_create_collection(
            name=collection_name
        )
        self.embedding_function = embedding_function

    def add_documents(
        self,
        documents: List[str],
        embeddings: Optional[List[List[float]]] = None
    ) -> None:
        """
        Add documents to the vector store

        Args:
            documents (List[str]): List of document texts
            embeddings (List[List[float]], optional): Pre-computed embeddings,
                one per document. When omitted (or empty), they are computed
                with ``self.embedding_function``.

        Raises:
            ValueError: If the number of embeddings does not match the number
                of documents.
        """
        # Nothing to store; skip the embedding call and the write entirely.
        if not documents:
            return

        # Explicit None/length test instead of truthiness: array-like
        # embeddings (e.g. NumPy) do not support ``not embeddings``.
        if embeddings is None or len(embeddings) == 0:
            embeddings = self.embedding_function(documents)

        if len(embeddings) != len(documents):
            raise ValueError(
                f"Got {len(embeddings)} embeddings for "
                f"{len(documents)} documents"
            )

        # IDs must be unique across calls, not only within one batch:
        # positional IDs would restart at zero on every call and clash with
        # entries already stored in the persistent collection.
        ids = [f"doc_{uuid.uuid4().hex}" for _ in documents]

        self.collection.add(
            documents=documents,
            embeddings=embeddings,
            ids=ids
        )

    def similarity_search(
        self,
        query_embedding: List[float],
        top_k: int = 3
    ) -> List[str]:
        """
        Perform similarity search

        Args:
            query_embedding (List[float]): Embedding of the query
            top_k (int): Number of top similar documents to retrieve

        Returns:
            List[str]: List of most similar documents; empty when the query
                result carries no documents
        """
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k
        )

        # One query embedding was sent, so only the first result batch is
        # relevant. ``or`` also covers a 'documents' key that is present
        # but set to None.
        batches = results.get('documents') or [[]]
        return batches[0]