File size: 3,319 Bytes
a2682b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc4aec6
a2682b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import logging
import os
from datetime import datetime
from typing import Any, Dict, List, Optional
from urllib.parse import quote

import vecs

class DatabaseService:
    """Semantic search over the ``articles`` vector collection via ``vecs``.

    Connection settings are read from the environment (``SUPABASE_HOST``,
    ``DB_PORT``, ``DB_NAME``, ``DB_USER``, ``DB_PASSWORD``). The constructor
    connects immediately and gets or creates the ``articles`` collection.
    """

    # Logger shared by all instances; used instead of print() for errors.
    _logger = logging.getLogger(__name__)

    def __init__(self):
        """Read connection settings from the environment and open the client.

        Raises:
            RuntimeError: if ``DB_PASSWORD`` is not set. There is deliberately
                no fallback value: credentials must not live in source code.
        """
        self.DB_HOST = os.getenv("SUPABASE_HOST", "db.daxquaudqidyeirypexa.supabase.co")
        self.DB_PORT = os.getenv("DB_PORT", "5432")
        self.DB_NAME = os.getenv("DB_NAME", "postgres")
        self.DB_USER = os.getenv("DB_USER", "postgres")
        # NOTE(security): a previous version of this file hard-coded a Supabase
        # JWT as the fallback here; treat that key as leaked and rotate it.
        self.DB_PASSWORD = os.getenv("DB_PASSWORD")
        if self.DB_PASSWORD is None:
            raise RuntimeError("DB_PASSWORD environment variable must be set")

        self.vx = vecs.create_client(self._build_connection_url())

        # Get or create the collection.
        self.articles = self.vx.get_or_create_collection(
            name="articles",
            dimension=384,  # must match the embedding model's output dimension
        )

    def _build_connection_url(self) -> str:
        """Return the Postgres connection URL with percent-encoded credentials.

        Without encoding, a user name or password containing characters such
        as ``@``, ``:``, ``/`` or ``%`` would corrupt the URL.
        """
        user = quote(self.DB_USER, safe="")
        password = quote(self.DB_PASSWORD, safe="")
        return f"postgresql://{user}:{password}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"

    async def semantic_search(
        self,
        query_embedding: List[float],
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        topic: Optional[str] = None,
        entities: Optional[List[str]] = None,
        limit: int = 10,
    ) -> List[Dict[str, Any]]:
        """Return up to ``limit`` articles nearest to ``query_embedding``.

        Args:
            query_embedding: the query vector (same dimension as the collection).
            start_date: if given, only articles with ``date >= start_date``.
            end_date: if given, only articles with ``date <= end_date``.
            topic: if given, only articles whose ``topic`` equals it.
            entities: if non-empty, adds an ``entities`` ``$in`` filter.
            limit: maximum number of results.

        Returns:
            A list of dicts with ``id``, ``url``, ``content``, ``date``,
            ``topic``, ``distance`` and ``similarity``. On any error the
            exception is logged and an empty list is returned (best effort).
        """
        # NOTE(review): the vecs calls here are synchronous and block the event
        # loop while they run despite this method being async.
        try:
            filters = self._build_filters(start_date, end_date, topic, entities)

            # include_value/include_metadata make vecs return
            # (id, distance, metadata) rows; without them it returns bare ids.
            rows = self.articles.query(
                data=query_embedding,
                limit=limit,
                filters=filters,
                measure="cosine_distance",  # or "l2_distance", "max_inner_product"
                include_value=True,
                include_metadata=True,
            )

            formatted_results = []
            for article_id, distance, metadata in rows:
                metadata = metadata or {}
                distance = float(distance)
                formatted_results.append({
                    "id": article_id,
                    "url": metadata.get("url"),
                    "content": metadata.get("content"),
                    "date": metadata.get("date"),
                    "topic": metadata.get("topic"),
                    "distance": distance,
                    # 1 - distance is a similarity only for cosine_distance.
                    "similarity": 1 - distance,
                })
            return formatted_results

        except Exception:
            self._logger.exception("Vector search error")
            return []

    def _build_filters(
        self,
        start_date: Optional[datetime],
        end_date: Optional[datetime],
        topic: Optional[str],
        entities: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Build a vecs metadata filter from the optional search constraints.

        vecs accepts a single key with a single operator per filter dict, so
        each constraint becomes its own clause. Two or more clauses are combined
        under ``$and``; one clause is returned as-is; none yields ``{}``.
        Either date bound may be given on its own.
        """
        clauses: List[Dict[str, Any]] = []

        # NOTE(review): dates are compared as ISO-8601 strings; this assumes
        # metadata["date"] is stored in the same ISO format/timezone — confirm.
        if start_date is not None:
            clauses.append({"date": {"$gte": start_date.isoformat()}})
        if end_date is not None:
            clauses.append({"date": {"$lte": end_date.isoformat()}})

        if topic:
            clauses.append({"topic": {"$eq": topic}})

        # NOTE(review): "$in" matches when metadata["entities"] is a scalar equal
        # to one of these values; if entities are stored as a list this will
        # not match and "$contains" may be needed — verify the stored schema.
        if entities:
            clauses.append({"entities": {"$in": list(entities)}})

        if not clauses:
            return {}
        if len(clauses) == 1:
            return clauses[0]
        return {"$and": clauses}

    async def close(self):
        """Disconnect the underlying vecs client."""
        self.vx.disconnect()