File size: 2,278 Bytes
7e327f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from smolagents import Tool
import modal

from .app import app
from .image import image
from .volume import volume


@app.cls(gpu="T4", image=image, volumes={"/volume": volume})
class RemoteObjectDetectionModelRetrieverModalApp:
    """Modal class that runs similarity search over a FAISS index of models.

    The index is read from the volume mounted at ``/volume`` and queried with
    ``all-MiniLM-L6-v2`` embeddings computed on the ``cuda`` device.
    """

    @modal.enter()
    def setup(self):
        """Load the FAISS vector store and keep it on ``self.vector_store``.

        Registered with ``modal.enter()`` (a Modal container-lifecycle hook),
        so the index is loaded up front instead of inside ``forward``.
        """
        embeddings = HuggingFaceEmbeddings(
            model_name="all-MiniLM-L6-v2",
            model_kwargs={"device": "cuda"},
            encode_kwargs={"normalize_embeddings": True},
            show_progress=True,
        )
        # SECURITY: loading a FAISS index unpickles data from disk, which is
        # why the opt-in flag below is required. Only acceptable because the
        # index lives on this project's own volume; never point this at
        # files from an untrusted source.
        self.vector_store = FAISS.load_local(
            folder_path="/volume/vector_store",
            embeddings=embeddings,
            index_name="object_detection_models_faiss_index",
            allow_dangerous_deserialization=True,
        )

    @modal.method()
    def forward(self, query: str) -> dict:
        """Return the models whose indexed documents best match ``query``.

        Args:
            query: Free-text description of the object class to detect.

        Returns:
            A dict mapping each hit's ``model_id`` metadata to its
            ``model_labels`` metadata, built from the 7 nearest documents.
            If several hits share a model id, the last one wins, so the dict
            may hold fewer than 7 entries.

        Raises:
            KeyError: If a retrieved document lacks ``model_id`` or
                ``model_labels`` in its metadata.
        """
        docs = self.vector_store.similarity_search(query, k=7)
        return {
            doc.metadata["model_id"]: doc.metadata["model_labels"]
            for doc in docs
        }


class RemoteObjectDetectionModelRetrieverTool(Tool):
    """smolagents tool that forwards a query to the deployed Modal retriever.

    The tool itself holds no index; it looks up the Modal class by name and
    calls its ``forward`` method remotely.
    """

    name = "object_detection_model_retriever"
    description = """
    For a given class of objects, retrieve the models that can detect that class.
    The query is a string that describes the class of objects the model needs to detect.
    The output is a dictionary with the model id as the key and the labels that the model can detect as the value.
    """
    inputs = {
        "query": {
            "type": "string",
            "description": "The class of objects the model needs to detect.",
        }
    }
    output_type = "object"

    def __init__(self):
        """Resolve a handle to the Modal class by app name and class name."""
        super().__init__()
        self.tool_class = modal.Cls.from_name(
            app.name, RemoteObjectDetectionModelRetrieverModalApp.__name__
        )

    def forward(self, query: str) -> dict:
        """Run the remote retriever and return its result unchanged.

        Args:
            query: Description of the object class the model must detect.

        Returns:
            Whatever the remote ``forward`` returns: a dict keyed by model id
            with that model's labels as the value.

        Raises:
            AssertionError: If ``query`` is not a string.
        """
        # NOTE: kept as an assert so callers see the same exception type as
        # before; be aware that asserts are skipped when Python runs with -O.
        assert isinstance(query, str), "Your search query must be a string"

        # A fresh handle instance is created for every call, as before.
        return self.tool_class().forward.remote(query)