# ScouterAI / remote_tools / object_detection_tool.py
# Commit 7e327f2 (stevenbucaille): Add initial project structure with core
# functionality for image processing agents
import modal
from transformers import AutoModelForObjectDetection, AutoImageProcessor
import torch
from smolagents import Tool
from .app import app
from .image import image
@app.cls(gpu="T4", image=image)
class RemoteObjectDetectionModalApp:
    """Modal class that runs a Hugging Face object-detection checkpoint.

    The checkpoint named by ``model_name`` is loaded once per container in
    :meth:`load_model` instead of on every ``forward`` call, and inference
    runs on CUDA when torch reports it as available (CPU otherwise).
    """

    # Identifier handed to ``from_pretrained`` for both model and processor.
    model_name: str = modal.parameter()

    @modal.enter()
    def load_model(self):
        """Load the processor and model once, at container start-up."""
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.processor = AutoImageProcessor.from_pretrained(self.model_name)
        self.model = AutoModelForObjectDetection.from_pretrained(
            self.model_name
        )
        # Use the requested GPU instead of leaving the weights on the CPU.
        self.model.to(self.device)
        self.model.eval()

    @modal.method()
    def forward(self, image, threshold: float = 0.5):
        """Detect objects in ``image`` and return them as plain dictionaries.

        Args:
            image: The image to analyse. Its ``size`` attribute is read as
                ``(width, height)``, as on a PIL image.
            threshold: Score threshold forwarded to the processor's
                ``post_process_object_detection``. Defaults to 0.5, the
                value that was previously hard-coded.

        Returns:
            A list of dicts with keys ``"box"`` (``[xmin, ymin, xmax, ymax]``),
            ``"score"`` (float) and ``"label"`` (name looked up in the model
            config's ``id2label``).
        """
        # Preprocess the image and put the tensors next to the model weights.
        inputs = self.processor(images=image, return_tensors="pt").to(
            self.device
        )

        with torch.no_grad():
            outputs = self.model(**inputs)

        # ``size`` is (width, height); post-processing expects (height, width).
        target_sizes = torch.tensor([image.size[::-1]], device=self.device)
        results = self.processor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=threshold
        )[0]

        id2label = self.model.config.id2label
        return [
            {
                "box": box.tolist(),  # [xmin, ymin, xmax, ymax]
                "score": score.item(),
                "label": id2label[label.item()],
            }
            for score, label, box in zip(
                results["scores"], results["labels"], results["boxes"]
            )
        ]
class RemoteObjectDetectionTool(Tool):
    """Tool that delegates object detection to the deployed Modal class.

    ``__init__`` looks up ``RemoteObjectDetectionModalApp`` by name on the
    Modal app; ``forward`` instantiates it for the requested model and runs
    its ``forward`` method remotely.
    """

    name = "object_detection"
    description = """
Given an image, detect objects and return bounding boxes.
The image is a PIL image.
The output is a list of dictionaries containing the bounding boxes with the following keys:
- box: a list of 4 numbers [xmin, ymin, xmax, ymax]
- score: a number between 0 and 1
- label: a string
The bounding boxes are in the format of [xmin, ymin, xmax, ymax].
You need to provide the model name to use for object detection.
The tool returns a list of bounding boxes for all the objects in the image.
"""

    inputs = {
        "image": {
            "type": "image",
            "description": "The image to detect objects in",
        },
        "model_name": {
            "type": "string",
            "description": "The name of the model to use for object detection",
        },
    }
    output_type = "object"

    def __init__(self):
        """Resolve the Modal class handle used by later calls."""
        super().__init__()
        remote_class_name = RemoteObjectDetectionModalApp.__name__
        self.tool_class = modal.Cls.from_name(app.name, remote_class_name)

    def forward(self, image, model_name: str):
        """Run remote detection on ``image`` with ``model_name``.

        Each detection is printed on its own line, then the list received
        from the remote call is returned unchanged.
        """
        # Parametrise the remote class with the requested checkpoint.
        self.tool = self.tool_class(model_name=model_name)
        detections = self.tool.forward.remote(image)

        for bbox in detections:
            print(
                f"Found {bbox['label']} with score: {bbox['score']} at box: {bbox['box']}"
            )

        return detections