# MCP_Track3_Discover / deploy_serverless_modal_image_processing.py
import io  # used by the remote methods to decode raw image bytes

import modal
import torch
from PIL import Image
# --- Modal App Setup ---
# This section defines the environment and models for our serverless functions.
# The container image will have all necessary libraries pre-installed.
modal_app = modal.App("video-analysis-app")
image = modal.Image.debian_slim().pip_install(
    "torch", "transformers", "Pillow"
)
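
# Two optional hardening steps, sketched as comments rather than applied here:
# pinning versions (e.g. "torch==2.3.1"; the exact numbers are illustrative)
# makes image rebuilds reproducible, and Modal's image.imports() context
# manager confines the heavy torch/PIL imports to the remote container, so
# `modal run` does not require them to be installed locally:
#
#     with image.imports():
#         import torch
#         from PIL import Image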
# Define a class so the models are loaded only once, when the container starts,
# rather than on every function call. Keeping models warm in memory this way is
# a standard Modal pattern for amortizing expensive startup work.
# We request a GPU to accelerate model inference.
@modal_app.cls(gpu="any", image=image)
class VideoAnalyzer:
    @modal.enter()
    def load_models(self):
        """
        Runs once when the container starts, via Modal's @modal.enter()
        lifecycle hook. It loads the models and processors into memory,
        so they are ready for immediate use by the methods below.
        """
        from transformers import (
            BlipForConditionalGeneration,
            BlipProcessor,
            DetrForObjectDetection,
            DetrImageProcessor,
        )

        print("Loading Image Captioning model...")
        # Model for describing the frame (image captioning)
        self.caption_processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-large"
        )
        self.caption_model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-large"
        ).to("cuda")
        print("Image Captioning model loaded.")

        print("Loading Object Detection model...")
        # Model for detecting objects
        self.detection_processor = DetrImageProcessor.from_pretrained(
            "facebook/detr-resnet-50"
        )
        self.detection_model = DetrForObjectDetection.from_pretrained(
            "facebook/detr-resnet-50"
        ).to("cuda")
        print("Object Detection model loaded.")
    @modal.method()
    def describe_frame(self, image_bytes: bytes) -> str:
        """
        Takes image bytes as input and returns a generated text description.
        """
        try:
            # Open the image from the raw bytes
            raw_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            # Preprocess the image and generate a caption on the GPU
            inputs = self.caption_processor(raw_image, return_tensors="pt").to("cuda")
            out = self.caption_model.generate(**inputs, max_new_tokens=50)
            # Decode the generated token IDs back into text
            caption = self.caption_processor.decode(out[0], skip_special_tokens=True)
            return caption
        except Exception as e:
            print(f"Error describing frame: {e}")
            return "Could not describe the image."
    @modal.method()
    def detect_objects(self, image_bytes: bytes, threshold: float = 0.9) -> list[str]:
        """
        Takes image bytes as input and returns a list of detected object labels.
        Only returns objects with a confidence score above the threshold.
        """
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            # Preprocess the image for the object detection model
            inputs = self.detection_processor(images=image, return_tensors="pt").to("cuda")
            # Inference only, so skip building the autograd graph
            with torch.no_grad():
                outputs = self.detection_model(**inputs)
            # Post-process the raw outputs into labels and scores. DETR expects
            # target sizes as (height, width); PIL's .size is (width, height).
            target_sizes = torch.tensor([image.size[::-1]])
            results = self.detection_processor.post_process_object_detection(
                outputs, target_sizes=target_sizes, threshold=threshold
            )[0]
            # Map label IDs to human-readable names; post-processing has already
            # dropped detections below the confidence threshold.
            detected_objects = []
            for label in results["labels"]:
                detected_objects.append(self.detection_model.config.id2label[label.item()])
            # Return a list of unique object names
            return list(set(detected_objects))
        except Exception as e:
            print(f"Error detecting objects: {e}")
            return []
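
    # A hypothetical convenience method, not part of the original script: run
    # both analyses in one remote call, saving a round trip and a duplicate
    # image upload when the caller always wants both results. Modal methods
    # can invoke sibling methods in-process via .local().
    @modal.method()
    def analyze_frame(self, image_bytes: bytes) -> dict:
        return {
            "description": self.describe_frame.local(image_bytes),
            "objects": self.detect_objects.local(image_bytes),
        }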
# --- Local Runner ---
# This part of the script runs on your local machine to call the serverless functions.
# It reads a local image file and sends it to the Modal functions for processing.
@modal_app.local_entrypoint()
def main(image_path: str):
    """
    A local entrypoint to test the Modal functions.
    Usage:
        modal run deploy_serverless_modal_image_processing.py --image-path /path/to/your/image.jpg
    """
    try:
        with open(image_path, "rb") as f:
            img_bytes = f.read()
    except FileNotFoundError:
        print(f"Error: Image file not found at '{image_path}'")
        return

    print("--- Calling Modal Functions ---")
    # Instantiate the class; on the first .remote() call, Modal starts a
    # container and runs load_models() before serving the request.
    analyzer = VideoAnalyzer()
    # Call the remote functions with the image data
    description = analyzer.describe_frame.remote(img_bytes)
    objects = analyzer.detect_objects.remote(img_bytes)

    print(f"\n📸 Analysis for: {image_path}")
    print("-----------------------------------")
    print(f"📝 Description: {description}")
    print(f"📦 Detected Objects: {objects}")
    print("-----------------------------------")