Spaces:
Running
on
Zero
Running
on
Zero
from langchain_community.vectorstores import FAISS
from smolagents import Tool

from rag.settings import get_vector_store
class ObjectDetectionModelRetrieverTool(Tool):
    """Agent tool that finds object-detection models for a given object class.

    The tool runs a similarity search over the vector store returned by
    ``get_vector_store()`` and reports, for each hit, the ``model_id`` and
    ``model_labels`` entries stored in the document metadata.
    """

    name = "object_detection_model_retriever"
    description = """
    For a given class of objects, retrieve the models that can detect that class.
    The query is a string that describes the class of objects the model needs to detect.
    The output is a dictionary with the model id as the key and the labels that the model can detect as the value.
    """
    inputs = {
        "query": {
            # Declared as a string so the schema agrees with what ``forward``
            # actually accepts (it rejects anything that is not a ``str``).
            "type": "string",
            "description": "The class of objects the model needs to detect.",
        }
    }
    output_type = "object"

    def __init__(self):
        super().__init__()

    def setup(self):
        """Load the vector store and keep it on ``self.vector_store``.

        NOTE(review): per the smolagents ``Tool`` docs this hook is meant for
        expensive initialisation and is run before the first call; confirm
        against the installed smolagents version.
        """
        self.vector_store = get_vector_store()
        print("Loaded vector store")

    def forward(self, query: str) -> dict:
        """Return the models whose indexed documents best match ``query``.

        Args:
            query: Text describing the class of objects to detect.

        Returns:
            A dict mapping each hit's ``model_id`` metadata value to its
            ``model_labels`` metadata value. If several hits share a
            ``model_id``, the last one returned by the search wins.

        Raises:
            TypeError: If ``query`` is not a string.
            KeyError: If a returned document lacks ``model_id`` or
                ``model_labels`` in its metadata.
        """
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")

        # Ask the store for the 7 nearest documents to the query.
        docs = self.vector_store.similarity_search(query, k=7)
        return {
            doc.metadata["model_id"]: doc.metadata["model_labels"]
            for doc in docs
        }