from langchain_community.vectorstores import FAISS
from smolagents import Tool

from rag.settings import get_vector_store


class ObjectDetectionModelRetrieverTool(Tool):
    """smolagents tool that finds object-detection models able to detect a class.

    The vector store comes from ``rag.settings.get_vector_store`` and is loaded
    in :meth:`setup`. Every retrieved document is read for two metadata
    entries, ``"model_id"`` and ``"model_labels"``.
    """

    name = "object_detection_model_retriever"
    description = """
    For a given class of objects, retrieve the models that can detect that class.
    The query is a string that describes the class of objects the model needs to detect.
    The output is a dictionary with the model id as the key and the labels that the model can detect as the value.
    """
    inputs = {
        "query": {
            # Must agree with the ``query: str`` hint on ``forward``; smolagents
            # checks the forward signature against this schema.
            "type": "string",
            "description": "The class of objects the model needs to detect.",
        }
    }
    output_type = "object"

    # Number of nearest documents requested from the vector store per query.
    top_k = 7

    def setup(self):
        """Load the vector store used by :meth:`forward`.

        smolagents calls ``setup`` lazily before the tool is first used, so the
        (possibly expensive) store load does not happen at construction time.
        """
        self.vector_store = get_vector_store()
        print("Loaded vector store")

    def forward(self, query: str) -> dict:
        """Return ``{model_id: model_labels}`` for the documents closest to *query*.

        Args:
            query: Free-text description of the object class to detect.

        Returns:
            A dict mapping each retrieved document's ``metadata["model_id"]`` to
            its ``metadata["model_labels"]``. If several documents share a
            model id, the value from the last one retrieved is kept.

        Raises:
            TypeError: If *query* is not a ``str``.
        """
        # Explicit check instead of ``assert``: asserts vanish under ``python -O``.
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")

        docs = self.vector_store.similarity_search(query, k=self.top_k)
        return {
            doc.metadata["model_id"]: doc.metadata["model_labels"] for doc in docs
        }