ScouterAI / remote_tools /rag_tool.py
stevenbucaille's picture
Add initial project structure with core functionality for image processing agents
7e327f2
raw
history blame
2.28 kB
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from smolagents import Tool
import modal
from .app import app
from .image import image
from .volume import volume
@app.cls(gpu="T4", image=image, volumes={"/volume": volume})
class RemoteObjectDetectionModelRetrieverModalApp:
    """Modal GPU service that finds object-detection models by similarity search.

    When a container starts, it loads a FAISS index from the Modal volume
    mounted at ``/volume``. Each ``forward`` call runs a similarity search
    against that index and returns the matching models' ids and labels.
    """

    @modal.enter()
    def setup(self):
        """Load the FAISS vector store and its embedding model when the container starts."""
        self.vector_store = FAISS.load_local(
            folder_path="/volume/vector_store",
            embeddings=HuggingFaceEmbeddings(
                model_name="all-MiniLM-L6-v2",
                model_kwargs={"device": "cuda"},
                encode_kwargs={"normalize_embeddings": True},
                show_progress=True,
            ),
            index_name="object_detection_models_faiss_index",
            # SECURITY: this flag lets load_local unpickle the stored docstore.
            # That is only safe when the index on the volume is trusted.
            # NOTE(review): presumably this project's own indexing step builds
            # the index. Confirm that, and never point folder_path at
            # externally supplied data.
            allow_dangerous_deserialization=True,
        )

    @modal.method()
    def forward(self, query: str, k: int = 7) -> dict:
        """Return the models whose index entries are most similar to ``query``.

        Args:
            query: Free-text description of the object class to detect.
            k: Number of nearest documents to retrieve. The default of 7
                matches the previously hard-coded value.

        Returns:
            A dict mapping each retrieved document's ``metadata["model_id"]``
            to its ``metadata["model_labels"]``. If several documents share a
            model id, the one that appears later in the search results wins.
        """
        docs = self.vector_store.similarity_search(query, k=k)
        # Build the mapping in one pass. This avoids the two parallel lists,
        # and the shadowed comprehension variable, of the earlier version.
        return {
            doc.metadata["model_id"]: doc.metadata["model_labels"] for doc in docs
        }
class RemoteObjectDetectionModelRetrieverTool(Tool):
    """smolagents tool that sends model-retrieval queries to a deployed Modal class.

    The vector search itself runs remotely in
    ``RemoteObjectDetectionModelRetrieverModalApp``. This wrapper only
    resolves that class by name and forwards the query to it.
    """

    name = "object_detection_model_retriever"
    description = """
For a given class of objects, retrieve the models that can detect that class.
The query is a string that describes the class of objects the model needs to detect.
The output is a dictionary with the model id as the key and the labels that the model can detect as the value.
"""
    inputs = {
        "query": {
            "type": "string",
            "description": "The class of objects the model needs to detect.",
        }
    }
    output_type = "object"

    def __init__(self):
        super().__init__()
        # Resolve the class through its app/class name so that calls go to the
        # deployed Modal app instead of the locally defined class object.
        self.tool_class = modal.Cls.from_name(
            app.name, RemoteObjectDetectionModelRetrieverModalApp.__name__
        )

    def forward(self, query: str) -> dict:
        """Retrieve the models that can detect the object class described by ``query``.

        Args:
            query: Description of the class of objects to detect.

        Returns:
            Whatever the remote ``forward`` returns: a dict mapping model ids
            to the labels each model can detect.

        Raises:
            TypeError: If ``query`` is not a ``str``.
        """
        # Check explicitly instead of using ``assert``. Assertions are removed
        # under ``python -O``, which would silently turn this validation off.
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")
        tool = self.tool_class()
        return tool.forward.remote(query)