Spaces:
Running
on
Zero
Running
on
Zero
from langchain_community.vectorstores import FAISS | |
from langchain_huggingface import HuggingFaceEmbeddings | |
from smolagents import Tool | |
import modal | |
from .app import app | |
from .image import image | |
from .volume import volume | |
class RemoteObjectDetectionModelRetrieverModalApp:
    """Worker that looks up object-detection models in a FAISS vector store.

    ``setup`` loads the index from ``/volume/vector_store`` and ``forward``
    answers similarity queries against it.

    NOTE(review): in this copy the class carries no Modal decorators
    (e.g. ``@app.cls(...)`` on the class, ``@modal.enter()`` on ``setup``,
    ``@modal.method()`` on ``forward``). However,
    ``RemoteObjectDetectionModelRetrieverTool`` resolves this class via
    ``modal.Cls.from_name`` and calls ``forward.remote``, and ``image`` /
    ``volume`` are imported but otherwise unused. Confirm the decorators are
    present in the real file; they may have been lost when this copy was made.
    """

    def setup(self):
        """Load the FAISS index into ``self.vector_store``."""
        self.vector_store = FAISS.load_local(
            folder_path="/volume/vector_store",
            embeddings=HuggingFaceEmbeddings(
                model_name="all-MiniLM-L6-v2",
                # Embeddings are computed on CUDA, so whatever runs setup()
                # needs a GPU.
                model_kwargs={"device": "cuda"},
                encode_kwargs={"normalize_embeddings": True},
                show_progress=True,
            ),
            index_name="object_detection_models_faiss_index",
            # SECURITY: load_local unpickles the stored docstore. This is only
            # acceptable while the index on the volume comes from a trusted
            # pipeline; never point it at user-supplied files.
            allow_dangerous_deserialization=True,
        )

    def forward(self, query: str, k: int = 7) -> dict:
        """Return the models whose indexed documents best match ``query``.

        Args:
            query: Free-text description of the object class to detect.
            k: Number of nearest documents to fetch from the vector store
                (defaults to 7, the previously hard-coded value).

        Returns:
            A dict mapping each hit's ``metadata["model_id"]`` to its
            ``metadata["model_labels"]``. If several hits share a model id,
            the later one in result order wins, so the dict can hold fewer
            than ``k`` entries.
        """
        docs = self.vector_store.similarity_search(query, k=k)
        return {
            doc.metadata["model_id"]: doc.metadata["model_labels"] for doc in docs
        }
class RemoteObjectDetectionModelRetrieverTool(Tool):
    """smolagents tool that forwards model-retrieval queries to Modal.

    The embedding model and FAISS index live in
    ``RemoteObjectDetectionModelRetrieverModalApp``. This tool only looks that
    class up by app and class name, calls its ``forward`` remotely, and
    returns the result unchanged.
    """

    name = "object_detection_model_retriever"
    # Same text as before, written as implicit concatenation so the class body
    # can be indented without altering the string's value.
    description = (
        "\n"
        "For a given class of objects, retrieve the models that can detect that class.\n"
        "The query is a string that describes the class of objects the model needs to detect.\n"
        "The output is a dictionary with the model id as the key and the labels that the model can detect as the value.\n"
    )
    inputs = {
        "query": {
            "type": "string",
            "description": "The class of objects the model needs to detect.",
        }
    }
    output_type = "object"

    def __init__(self):
        super().__init__()
        # Handle to the remote class, looked up by the app's name and the
        # worker class's name.
        self.tool_class = modal.Cls.from_name(
            app.name, RemoteObjectDetectionModelRetrieverModalApp.__name__
        )

    # NOTE: smolagents checks that forward's parameters match ``inputs``
    # exactly, so the signature must stay ``(self, query)``.
    def forward(self, query: str) -> dict:
        """Run the retrieval remotely and return its result.

        Args:
            query: Description of the object class to search for.

        Returns:
            Whatever the remote ``forward`` returns. Per this tool's
            description, that is a dict of model id -> detectable labels.

        Raises:
            TypeError: If ``query`` is not a string.
        """
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently disable this validation.
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")
        remote_worker = self.tool_class()
        return remote_worker.forward.remote(query)