File size: 4,207 Bytes
a75702e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from chromadb import PersistentClient, EmbeddingFunction, Embeddings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from typing import List
import json


MODEL_NAME = 'dunzhang/stella_en_1.5B_v5'
DB_PATH = './.chroma_db'
FAQ_FILE_PATH= './data/FAQ.json'
INVENTORY_FILE_PATH = './data/inventory.json'

class Product:
    def __init__(self, name: str, id: str, description: str, type: str, price: float, quantity: int):
        self.name = name
        self.id = id
        self.description = description
        self.type = type
        self.price = price
        self.quantity = quantity

class QuestionAnswerPairs:
    def __init__(self, question: str, answer: str):
        self.question = question
        self.answer = answer

class CustomEmbeddingClass(EmbeddingFunction):
    def __init__(self, model_name):
        self.embedding_model = HuggingFaceEmbedding(model_name=MODEL_NAME)

    def __call__(self, input_texts: List[str]) -> Embeddings:
        return [self.embedding_model.get_text_embedding(text) for text in input_texts]

class FAQCollection:
    def __init__(self):
        self.documents = []
        self.ids = []
        self.metadatas = []

    def add(self, documents, ids, metadatas):
        self.documents.extend(documents)
        self.ids.extend(ids)
        self.metadatas.extend(metadatas)

    def display(self):
        for doc, id_, meta in zip(self.documents, self.ids, self.metadatas):
            print(f"ID: {id_}, Document: {doc}, Metadata: {meta}")

# Define the InventoryCollection class
class InventoryCollection:
    def __init__(self):
        self.documents = []
        self.ids = []
        self.metadatas = []

    def add(self, documents, ids, metadatas):
        self.documents.extend(documents)
        self.ids.extend(ids)
        self.metadatas.extend(metadatas)

    def display(self):
        for doc, id_, meta in zip(self.documents, self.ids, self.metadatas):
            print(f"ID: {id_}, Document: {doc}, Metadata: {meta}")




class FlowerShopVectorStore:
    def __init__(self):
        db = PersistentClient(path=DB_PATH)

        custom_embedding_function = CustomEmbeddingClass(MODEL_NAME)

        self.faq_collection = db.get_or_create_collection(name='FAQ', embedding_function=custom_embedding_function)
        self.inventory_collection = db.get_or_create_collection(name='Inventory', embedding_function=custom_embedding_function)

        if self.faq_collection.count() == 0:
            try :
                self._load_faq_collection(FAQ_FILE_PATH)
            except Exception as e:
                raise ValueError(e)
                

        if self.inventory_collection.count() == 0:
            self._load_inventory_collection(INVENTORY_FILE_PATH)

    def _load_faq_collection(self, faq_file_path: str):
        try:
            
            with open(faq_file_path, 'r') as f:
                faqs = json.load(f)
                
            # Create an instance of FAQCollection
            obj_faq_collection = FAQCollection()

            obj_faq_collection.add(
                documents=[faq['question'] for faq in faqs] + [faq['answer'] for faq in faqs],
                ids=[str(i) for i in range(0, 2*len(faqs))],
                metadatas = faqs + faqs
            )
            self.faq_collection = obj_faq_collection
        except Exception as ex:
            raise ValueError(ex)

    def _load_inventory_collection(self, inventory_file_path: str):
        with open(inventory_file_path, 'r') as f:
            inventories = json.load(f)
            
        # Create an instance of InventoryCollection
        obj_inventory_collection = InventoryCollection()

        obj_inventory_collection.add(
            documents=[inventory['description'] for inventory in inventories],
            ids=[str(i) for i in range(0, len(inventories))],
            metadatas = inventories
        )
        self.inventory_collection = obj_inventory_collection

    def query_faqs(self, query: str): 
        return self.faq_collection.query(query_texts=[query], n_results=5)
    
    def query_inventories(self, query: str):
        return self.inventory_collection.query(query_texts=[query], n_results=5)