Spaces:
Sleeping
Sleeping
File size: 2,212 Bytes
0eb636f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import faiss
import json
import numpy as np
from pathlib import Path
from src.utils.config import VECTOR_DB_PATH, EMBEDDING_DIM
from typing import List
class VectorStore:
    """
    Wrapper for FAISS vector storage, with ID-to-text mapping.

    Vectors live in a flat L2 FAISS index; their metadata lives in a parallel
    Python list. The two are linked purely by position: the metadata for the
    vector at FAISS row ``i`` is ``self.metadata[i]``, so both must always
    have the same length.
    """

    def __init__(self, index_path: Path = VECTOR_DB_PATH):
        """
        Create the store, loading previously saved state when available.

        Args:
            index_path: Base path of the store. Its suffix is replaced with
                ``.index`` for the FAISS file and ``.json`` for the metadata.
        """
        index_path = Path(index_path)
        self.index_path = index_path.with_suffix(".index")
        self.meta_path = index_path.with_suffix(".json")
        self._reset()

        # Only attempt a load when both halves of the store are on disk.
        if self.index_path.exists() and self.meta_path.exists():
            try:
                self.load()
            except Exception as e:
                # Best effort: an unreadable store must not prevent start-up.
                print(f"[WARN] Failed to load vector store: {e}")
                # Reinitialize clean if corrupted
                self._reset()

    def _reset(self) -> None:
        """Replace the in-memory state with an empty index and no metadata."""
        self.index = faiss.IndexFlatL2(EMBEDDING_DIM)
        self.metadata = []  # list of dicts: {"id": str, "text": str}

    def add(self, embeddings: list[list[float]], metadata: List[dict]):
        """
        Add new embeddings and their metadata (e.g., {"id": "doc1_chunk0", "text": "..."})

        The store is written to disk after a successful add.

        Args:
            embeddings: One vector per item.
            metadata: One dict per vector, in the same order as ``embeddings``.

        Raises:
            ValueError: If ``embeddings`` and ``metadata`` differ in length.
        """
        if len(embeddings) != len(metadata):
            # A mismatch would shift every later row against its metadata.
            raise ValueError(
                f"embeddings ({len(embeddings)}) and metadata ({len(metadata)}) "
                "must have the same length"
            )
        if len(embeddings) == 0:
            # Nothing to add; an empty batch has no 2-D shape to hand to FAISS.
            return

        vectors = np.ascontiguousarray(embeddings, dtype="float32")
        self.index.add(vectors)
        # Extend metadata only after FAISS accepted the vectors, so a failed
        # add cannot leave the two sides out of step.
        self.metadata.extend(metadata)
        self.save()

    def search(self, query_embedding: list[float], top_k: int = 5) -> List[dict]:
        """
        Perform vector search and return metadata of top_k results.

        Args:
            query_embedding: The query vector.
            top_k: Maximum number of results to return.

        Returns:
            Metadata dicts of the matches in the order FAISS returns them.
            Fewer than ``top_k`` items (possibly none) are returned when the
            store holds fewer vectors.
        """
        if top_k <= 0 or self.index.ntotal == 0:
            return []

        query = np.ascontiguousarray([query_embedding], dtype="float32")
        k = min(top_k, self.index.ntotal)
        _, indices = self.index.search(query, k)

        # FAISS marks missing results with label -1; indexing the metadata
        # list with it would silently return the last record instead.
        return [
            self.metadata[int(i)]
            for i in indices[0]
            if 0 <= i < len(self.metadata)
        ]

    def save(self) -> None:
        """
        Save data to an external file.

        Writes the FAISS index to ``self.index_path`` and the metadata list as
        UTF-8 JSON to ``self.meta_path``, creating the parent directory first.
        """
        self.index_path.parent.mkdir(parents=True, exist_ok=True)
        faiss.write_index(self.index, str(self.index_path))
        with open(self.meta_path, "w", encoding="utf-8") as f:
            json.dump(self.metadata, f, ensure_ascii=False, indent=2)

    def load(self) -> None:
        """
        Load data from an external file.

        The in-memory state is replaced only after both files were read and
        found consistent, so a failed load leaves the current state untouched.

        Raises:
            ValueError: If the metadata file does not contain a list, or its
                length differs from the number of vectors in the index.
        """
        index = faiss.read_index(str(self.index_path))
        with open(self.meta_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)

        if not isinstance(metadata, list):
            raise ValueError("metadata file must contain a JSON list")
        if index.ntotal != len(metadata):
            # Results are mapped to metadata by row position, so the counts
            # have to agree for lookups to be meaningful.
            raise ValueError(
                f"index holds {index.ntotal} vectors but metadata has "
                f"{len(metadata)} entries"
            )

        self.index = index
        self.metadata = metadata
|