"""Semantic search over object-detection models, served on Modal.

A FAISS vector store kept on a Modal volume is queried inside a GPU-backed
Modal class, and a smolagents ``Tool`` exposes that search to agents.
"""

import modal
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from smolagents import Tool

from .app import app
from .image import image
from .volume import volume


@app.cls(gpu="T4", image=image, volumes={"/volume": volume})
class RemoteObjectDetectionModelRetrieverModalApp:
    """Modal class that holds a FAISS index of object-detection models.

    The index is loaded once per container in :meth:`setup` (``@modal.enter``)
    from the mounted volume and queried through :meth:`forward`.
    """

    @modal.enter()
    def setup(self) -> None:
        """Load the FAISS index from the volume with a CUDA embedding model."""
        self.vector_store = FAISS.load_local(
            folder_path="/volume/vector_store",
            embeddings=HuggingFaceEmbeddings(
                model_name="all-MiniLM-L6-v2",
                model_kwargs={"device": "cuda"},
                encode_kwargs={"normalize_embeddings": True},
                show_progress=True,
            ),
            index_name="object_detection_models_faiss_index",
            # FAISS.load_local unpickles the stored docstore. That is only
            # acceptable because the index is read from our own Modal volume
            # (assumed trusted); never point this at untrusted files.
            allow_dangerous_deserialization=True,
        )

    @modal.method()
    def forward(self, query: str, k: int = 7) -> dict:
        """Retrieve the models whose indexed documents best match *query*.

        Args:
            query: Free-text description of the object class to detect.
            k: Number of documents to fetch from the vector store (default 7).

        Returns:
            A dict mapping each hit's ``metadata["model_id"]`` to its
            ``metadata["model_labels"]``. If several hits share a model id,
            the one appearing later in the search results overwrites the
            earlier entry.
        """
        docs = self.vector_store.similarity_search(query, k=k)
        return {
            doc.metadata["model_id"]: doc.metadata["model_labels"] for doc in docs
        }


class RemoteObjectDetectionModelRetrieverTool(Tool):
    """smolagents tool that delegates model retrieval to the Modal class.

    Embedding and FAISS search run remotely; this wrapper only validates the
    query and forwards it to ``RemoteObjectDetectionModelRetrieverModalApp``.
    """

    name = "object_detection_model_retriever"
    description = """
    For a given class of objects, retrieve the models that can detect that class.
    The query is a string that describes the class of objects the model needs to detect.
    The output is a dictionary with the model id as the key and the labels that the model can detect as the value.
    """
    inputs = {
        "query": {
            "type": "string",
            "description": "The class of objects the model needs to detect.",
        }
    }
    output_type = "object"

    def __init__(self):
        super().__init__()
        # Resolve the Modal class by app name and class name; presumably the
        # app must already be deployed for this lookup to succeed.
        self.tool_class = modal.Cls.from_name(
            app.name, RemoteObjectDetectionModelRetrieverModalApp.__name__
        )

    def forward(self, query: str) -> dict:
        """Run the remote retrieval for *query*.

        Args:
            query: Description of the object class the model must detect.

        Returns:
            The dict produced by the remote ``forward`` (model id -> labels).

        Raises:
            TypeError: If *query* is not a string.
        """
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")
        tool = self.tool_class()
        return tool.forward.remote(query)