Víctor Sáez
Update Gradio interface and add arial.ttf tracked via LFS
6ecfb14
raw
history blame
2.65 kB
import gradio as gr
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import DetrImageProcessor, DetrForObjectDetection
from pathlib import Path
# Load the pretrained DETR checkpoint and its matching image processor from
# Hugging Face. Both are created once at import time and reused by every request.
model_name = "facebook/detr-resnet-50"
processor = DetrImageProcessor.from_pretrained(model_name)
model = DetrForObjectDetection.from_pretrained(model_name)

# Load the font used for detection labels. Fall back to PIL's built-in font
# when the bundled file is missing OR present but unreadable as a font
# (ImageFont.truetype raises OSError in that case; a plausible cause is an
# unfetched Git LFS pointer stub sitting at this path -- TODO confirm).
font_path = Path("assets/fonts/arial.ttf")
font = None
if font_path.exists():
    try:
        font = ImageFont.truetype(str(font_path), size=100)
    except OSError as err:
        print(f"Font file {font_path} could not be loaded ({err}). Using default font.")
else:
    print(f"Font file {font_path} not found. Using default font.")
if font is None:
    font = ImageFont.load_default()

# Informational only: this block does not move the model to a GPU.
print(f"CUDA is available: {torch.cuda.is_available()}")
def _draw_label(draw, box, label_text):
    """Draw ``label_text`` on a black strip just above the top-left corner of ``box``.

    Args:
        draw: ``ImageDraw.Draw`` object for the image being annotated.
        box: ``[x0, y0, x1, y1]`` corner coordinates of the detection box.
        label_text: Text to render (class name and score).
    """
    # Ink bounds of the text when anchored at the origin. ``left``/``top`` can
    # be non-zero, so they are needed both for the size and for the offset.
    left, top, right, bottom = draw.textbbox((0, 0), label_text, font=font)
    text_width = right - left
    text_height = bottom - top

    # Place the strip above the box, but clamp to the canvas so labels of
    # boxes touching the top/left edge stay visible instead of being cut off.
    x = max(box[0], 0)
    y = max(box[1] - text_height, 0)

    draw.rectangle([x, y, x + text_width, y + text_height], fill="black")
    # Shift the anchor by the ink offset so the glyphs sit exactly on the strip.
    draw.text((x - left, y - top), label_text, fill="white", font=font)


def detect_objects(image):
    """Run DETR on ``image`` and return a copy annotated with boxes and labels.

    Args:
        image: PIL image to analyse, or ``None`` when no image was supplied.

    Returns:
        A copy of ``image`` with a red rectangle and a "<class>: <score>"
        label for every detection whose score passes the 0.9 threshold, or
        ``None`` when ``image`` is ``None``. The input image is not modified.
    """
    # The UI may invoke the callback with no image uploaded; nothing to do then.
    if image is None:
        return None

    inputs = processor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports size as (width, height); reverse it to (height, width)
    # before handing it to the post-processing step.
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, threshold=0.9, target_sizes=target_sizes
    )[0]

    # Draw on a copy so the caller's image stays untouched.
    image_with_boxes = image.copy()
    draw = ImageDraw.Draw(image_with_boxes)
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        box = [round(x, 2) for x in box.tolist()]
        draw.rectangle(box, outline="red", width=3)
        label_text = f"{model.config.id2label[label.item()]}: {round(score.item(), 2)}"
        _draw_label(draw, box, label_text)
    return image_with_boxes
# Gradio UI: a heading row, a row with the input and output images side by
# side, and a row with the button that triggers detection.
with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown(
            "## Object Detection App\n"
            "Upload an image to detect objects using Facebook's DETR model."
        )

    with gr.Row():
        # type="pil" makes Gradio hand the upload to the callback as a PIL image.
        input_image = gr.Image(type="pil", label="Input Image")
        output_image = gr.Image(label="Detected Objects")

    with gr.Row():
        button = gr.Button("Detect Objects")

    # Wire the button: the uploaded image goes in, the annotated image comes out.
    button.click(detect_objects, inputs=input_image, outputs=output_image)

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    app.launch()