File size: 1,358 Bytes
7e327f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from langchain_community.vectorstores import FAISS
from smolagents import Tool
from rag.settings import get_vector_store


class ObjectDetectionModelRetrieverTool(Tool):
    """Tool that looks up object-detection models able to detect a given class.

    The query is matched against a vector store obtained from
    ``get_vector_store()``; each hit is expected to carry ``model_id`` and
    ``model_labels`` entries in its metadata.
    """

    name = "object_detection_model_retriever"
    description = """
    For a given class of objects, retrieve the models that can detect that class.
    The query is a string that describes the class of objects the model needs to detect.
    The output is a dictionary with the model id as the key and the labels that the model can detect as the value.
    """
    inputs = {
        "query": {
            # The query is free text (see ``forward``), so it is advertised as
            # a string rather than an object.
            "type": "string",
            "description": "The class of objects the model needs to detect.",
        }
    }
    output_type = "object"

    def __init__(self):
        super().__init__()

    def setup(self):
        """Load the vector store used by ``forward``.

        Marks the tool as initialized afterwards so that the base class does
        not run this (potentially expensive) loading step again on later calls.
        """
        self.vector_store = get_vector_store()
        print("Loaded vector store")
        self.is_initialized = True

    def forward(self, query: str) -> dict:
        """Return the models matching ``query``.

        Args:
            query: Text describing the class of objects to detect.

        Returns:
            A dictionary mapping each retrieved ``model_id`` to the
            ``model_labels`` stored alongside it. When several hits share a
            model id, the last hit wins.

        Raises:
            TypeError: If ``query`` is not a string.
        """
        # Explicit check instead of ``assert`` so validation survives ``-O``.
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")

        docs = self.vector_store.similarity_search(query, k=7)
        return {
            doc.metadata["model_id"]: doc.metadata["model_labels"]
            for doc in docs
        }