# Spaces: Runtime error
import io

import modal
import torch
from PIL import Image
# --- Modal App Setup ---
# This section defines the environment and models for our serverless functions.
# The container image will have all necessary libraries pre-installed.
# The Modal app that the class and entrypoint in this file attach to.
modal_app = modal.App("video-analysis-app")

# Container image with the inference dependencies pre-installed.
image = modal.Image.debian_slim().pip_install(
    "torch",
    "transformers",
    "Pillow",
    # The default facebook/detr-resnet-50 checkpoint builds its backbone with
    # timm; without it, loading the detection model fails in the container.
    "timm",
)
# Define a class to load models only once when the container starts.
# This is a Modal best practice that avoids slow cold starts.
# We request a GPU to accelerate the model inference.
@modal_app.cls(gpu="any", image=image)
class VideoAnalyzer:
    """Image captioning (BLIP) and object detection (DETR) served from Modal.

    The class is registered on the Modal app with a GPU and the prepared
    container image. Models are loaded once per container by ``load_models``
    and reused by every call to ``describe_frame`` / ``detect_objects``.
    """

    @modal.enter()
    def load_models(self) -> None:
        """Load both models and their processors onto the GPU.

        Runs once when the container starts, so the models are ready for
        immediate use by the methods. This uses the ``modal.enter`` lifecycle
        decorator instead of an ``__enter__`` method.
        """
        # Imported lazily: transformers is installed in the container image and
        # is only needed once the models are actually loaded.
        from transformers import (
            BlipForConditionalGeneration,
            BlipProcessor,
            DetrForObjectDetection,
            DetrImageProcessor,
        )

        print("Loading Image Captioning model...")
        # Model for describing the frame (image captioning)
        self.caption_processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-large"
        )
        self.caption_model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-large"
        ).to("cuda")
        print("Image Captioning model loaded.")

        print("Loading Object Detection model...")
        # Model for detecting objects
        self.detection_processor = DetrImageProcessor.from_pretrained(
            "facebook/detr-resnet-50"
        )
        self.detection_model = DetrForObjectDetection.from_pretrained(
            "facebook/detr-resnet-50"
        ).to("cuda")
        print("Object Detection model loaded.")

    @modal.method()
    def describe_frame(self, image_bytes: bytes) -> str:
        """Return a generated text description of an encoded image.

        Args:
            image_bytes: Raw bytes of an image file readable by PIL.

        Returns:
            The generated caption, or ``"Could not describe the image."`` if
            decoding or inference fails (the error is printed, not raised).
        """
        try:
            # Open the image from the raw bytes
            frame = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            # Process the image and generate a caption
            inputs = self.caption_processor(frame, return_tensors="pt").to("cuda")
            token_ids = self.caption_model.generate(**inputs, max_new_tokens=50)
            # Decode the caption and return it
            return self.caption_processor.decode(
                token_ids[0], skip_special_tokens=True
            )
        except Exception as e:
            # Best-effort: report the failure and return a placeholder caption.
            print(f"Error describing frame: {e}")
            return "Could not describe the image."

    @modal.method()
    def detect_objects(self, image_bytes: bytes, threshold: float = 0.9) -> list[str]:
        """Return the labels of objects detected in an encoded image.

        Args:
            image_bytes: Raw bytes of an image file readable by PIL.
            threshold: Minimum confidence score for a detection to be kept.

        Returns:
            A sorted list of unique object names, or an empty list if decoding
            or inference fails (the error is printed, not raised).
        """
        try:
            frame = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            # Process the image for the object detection model
            inputs = self.detection_processor(images=frame, return_tensors="pt").to(
                "cuda"
            )
            # Inference only: skip autograd bookkeeping to save GPU memory.
            with torch.no_grad():
                outputs = self.detection_model(**inputs)
            # PIL reports (width, height); reverse it to (height, width) for
            # post-processing.
            target_sizes = torch.tensor([frame.size[::-1]])
            results = self.detection_processor.post_process_object_detection(
                outputs, target_sizes=target_sizes, threshold=threshold
            )[0]
            # Post-processing has already applied the threshold, so only the
            # labels are needed. Sorting keeps the output order deterministic.
            id2label = self.detection_model.config.id2label
            return sorted({id2label[label.item()] for label in results["labels"]})
        except Exception as e:
            # Best-effort: report the failure and return no detections.
            print(f"Error detecting objects: {e}")
            return []
# --- Local Runner ---
# This part of the script runs on your local machine to call the serverless functions.
# It reads a local image file and sends it to the Modal functions for processing.
@modal_app.local_entrypoint()
def main(image_path: str):
    """
    A local entrypoint to test the Modal functions.

    Reads the image at ``image_path``, sends its bytes to the remote
    ``VideoAnalyzer`` methods, and prints the caption and detected objects.
    Prints an error and returns early if the file does not exist.

    Usage:
        modal run this_script_name.py --image-path /path/to/your/image.jpg
    """
    try:
        with open(image_path, "rb") as f:
            img_bytes = f.read()
    except FileNotFoundError:
        print(f"Error: Image file not found at '{image_path}'")
        return

    print("--- Calling Modal Functions ---")
    # Create a handle to the analyzer class.
    analyzer = VideoAnalyzer()
    # Call the remote functions with the image data
    description = analyzer.describe_frame.remote(img_bytes)
    objects = analyzer.detect_objects.remote(img_bytes)

    print(f"\n📸 Analysis for: {image_path}")
    print("-----------------------------------")
    print(f"📝 Description: {description}")
    print(f"📦 Detected Objects: {objects}")
    print("-----------------------------------")