# Source: Hugging Face file by venkyvicky (commit 8eed584, "Upload 2 files").
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
# Load the Mask2Former processor and model (Swin-Small backbone, COCO instance
# checkpoint) from the Hugging Face Hub once, at import time.
MODEL_CHECKPOINT = "facebook/mask2former-swin-small-coco-instance"
processor = AutoImageProcessor.from_pretrained(MODEL_CHECKPOINT)
model = Mask2FormerForUniversalSegmentation.from_pretrained(MODEL_CHECKPOINT)
model.eval()  # inference mode: disables dropout and similar training-only layers

# Class-id -> class-name map taken from the model config (transformers'
# PretrainedConfig stores id2label with int keys). The fallback is an empty
# dict, not a list, so consumers can always use mapping-style lookups such as
# ``.get(label_id, str(label_id))`` and fall back to the numeric id themselves.
COCO_INSTANCE_CATEGORY_NAMES = dict(getattr(model.config, "id2label", None) or {})
def _label_name(label_id):
    """Return a human-readable class name for ``label_id``.

    ``COCO_INSTANCE_CATEGORY_NAMES`` may be a dict (the model's ``id2label``)
    or a sequence of names; unknown ids fall back to the id as a string.
    """
    names = COCO_INSTANCE_CATEGORY_NAMES
    if isinstance(names, dict):
        # transformers normalises id2label keys to int; also accept str keys
        # in case the mapping was built from raw JSON.
        if label_id in names:
            return names[label_id]
        return names.get(str(label_id), str(label_id))
    if isinstance(label_id, int) and 0 <= label_id < len(names):
        return names[label_id]
    return str(label_id)


def _run_inference(image, threshold):
    """Run Mask2Former on a PIL ``image``.

    Returns ``(segmentation_map, segments_info)``: an [H, W] numpy array of
    per-pixel instance ids, and a list of dicts with keys ``id``,
    ``label_id`` and ``score`` describing each kept instance.
    """
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # PIL's ``size`` is (W, H); the post-processor wants (H, W).
    # The threshold is passed through explicitly: the post-processor applies
    # its own default of 0.5, which would otherwise silently drop every
    # instance the UI slider allows below 0.5.
    results = processor.post_process_instance_segmentation(
        outputs,
        threshold=threshold,
        target_sizes=[image.size[::-1]],
    )[0]
    segmentation_map = results["segmentation"].cpu().numpy()
    return segmentation_map, results["segments_info"]


def _render_overlay(image, segmentation_map, segments_info, threshold):
    """Build a figure of ``image`` with a colour mask, box and caption per instance.

    Instances scoring below ``threshold`` or with an empty mask are skipped.
    Uses an object-oriented ``Figure`` (not pyplot) so concurrent requests do
    not share pyplot's global state and no figure needs explicit closing.
    """
    overlay = np.array(image)  # (H, W, 3) uint8 because the caller converts to RGB
    annotations = []
    for segment in segments_info:
        score = segment.get("score", 1.0)
        if score < threshold:
            continue
        mask = segmentation_map == segment["id"]
        ys, xs = np.nonzero(mask)
        if xs.size == 0:
            continue
        color = np.random.rand(3)  # random RGB in [0, 1) per instance
        # 50/50 alpha blend of the instance pixels with its colour.
        overlay[mask] = (overlay[mask] * 0.5 + color * 255 * 0.5).astype(np.uint8)
        caption = f"{_label_name(segment['label_id'])}: {score:.2f}"
        annotations.append((xs.min(), ys.min(), xs.max(), ys.max(), color, caption))

    fig = plt.Figure(figsize=(10, 10))
    ax = fig.subplots()
    # Draw the image first so it fixes the axes extent, then the boxes/labels.
    ax.imshow(overlay)
    for x1, y1, x2, y2, color, caption in annotations:
        ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                   fill=False, color=color, linewidth=2))
        ax.text(x1, y1, caption,
                bbox=dict(facecolor='yellow', alpha=0.5), fontsize=10)
    ax.axis('off')
    return fig


def segment_image(image, threshold=0.5):
    """Segment object instances in ``image`` and render them to a PNG file.

    Gradio entry point, wired to an image upload and a confidence slider.

    Args:
        image: PIL image from the upload widget (``None`` if nothing uploaded).
        threshold: minimum instance score, in [0, 1], to keep and draw.

    Returns:
        Path of a newly created PNG in the system temp directory showing the
        image with a translucent colour mask, bounding box and
        "label: score" caption for each kept instance. The file is not
        deleted here.

    Raises:
        gr.Error: if no image was provided (shown to the user in the UI).
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    # Mask2Former and the RGB blend expect 3 channels; RGBA, greyscale and
    # palette uploads would otherwise fail to broadcast against the colour.
    image = image.convert("RGB")

    segmentation_map, segments_info = _run_inference(image, threshold)
    fig = _render_overlay(image, segmentation_map, segments_info, threshold)

    # A unique file per request: Gradio can serve requests concurrently, and a
    # fixed output path let one request overwrite another's result.
    with tempfile.NamedTemporaryFile(
        suffix=".png", prefix="mask2former_", delete=False
    ) as tmp:
        fig.savefig(tmp, format="png", bbox_inches="tight", pad_inches=0)
    return tmp.name
# ---------------------------------------------------------------------------
# Gradio front end: an uploaded image and a confidence slider go in; the file
# path of the rendered PNG returned by segment_image comes out.
# ---------------------------------------------------------------------------
interface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.5,
            step=0.05,
            label="Confidence Threshold",
        ),
    ],
    outputs=gr.Image(type="filepath", label="Segmented Output"),
    title="Mask2Former Instance Segmentation (Transformer)",
    description=(
        "Upload an image to segment objects using Facebook's "
        "transformer-based Mask2Former model (Swin-Small backbone)."
    ),
)


if __name__ == "__main__":
    # share=True requests a public share link; debug=True blocks the main
    # thread and prints errors to the console.
    interface.launch(share=True, debug=True)