Spaces:
Runtime error
Runtime error
File size: 1,358 Bytes
7e327f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
from langchain_community.vectorstores import FAISS
from smolagents import Tool
from rag.settings import get_vector_store
class ObjectDetectionModelRetrieverTool(Tool):
    """smolagents tool that looks up object-detection models in a vector store.

    The query text is run through a similarity search on the store returned by
    ``get_vector_store()``, and the ``model_id`` / ``model_labels`` metadata of
    the matching documents is returned as a dict.
    """

    name = "object_detection_model_retriever"
    description = """
For a given class of objects, retrieve the models that can detect that class.
The query is a string that describes the class of objects the model needs to detect.
The output is a dictionary with the model id as the key and the labels that the model can detect as the value.
"""
    # The declared type must agree with the ``query: str`` hint on ``forward``:
    # smolagents checks declared input types against forward's type hints, and
    # the previous "object" declaration did not match a ``str`` parameter.
    inputs = {
        "query": {
            "type": "string",
            "description": "The class of objects the model needs to detect.",
        }
    }
    output_type = "object"

    def setup(self):
        """Load the vector store (smolagents calls this before first use)."""
        self.vector_store = get_vector_store()
        print("Loaded vector store")

    def forward(self, query: str) -> dict:
        """Return the models whose indexed documents best match *query*.

        Args:
            query: Free-text description of the object class to detect.

        Returns:
            Mapping of each hit's ``model_id`` metadata to its
            ``model_labels`` metadata, for the top 7 similarity-search hits.
            If several hits share a model id, the later hit's labels win.

        Raises:
            TypeError: if *query* is not a ``str``.
        """
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")
        docs = self.vector_store.similarity_search(query, k=7)
        return {
            doc.metadata["model_id"]: doc.metadata["model_labels"]
            for doc in docs
        }
|