# Spaces: Runtime error
import cv2 | |
import os | |
import io | |
from PIL import Image | |
import torch | |
from transformers import BlipProcessor, BlipForConditionalGeneration, DetrImageProcessor, DetrForObjectDetection | |
from collections import Counter | |
class VideoAnalyzer:
    """Caption image frames and count the objects detected in them.

    Holds a BLIP captioning model and a DETR object-detection model, each
    with its processor; all four are loaded once in ``__init__``.
    """

    def __init__(self):
        """Initialize the models."""
        print("Loading Image Captioning model...")
        self.caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
        self.caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
        print("Loading Object Detection model...")
        self.detection_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
        self.detection_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

    def describe_frame(self, image_path: str) -> str:
        """Generate a text description of the frame.

        Args:
            image_path: Path of the image file to caption.

        Returns:
            The generated caption, or ``"Could not describe the image."``
            if anything fails (the error is printed, not raised).
        """
        try:
            # Close the file handle promptly; convert() returns an
            # independent in-memory copy that outlives the `with` block.
            with Image.open(image_path) as source:
                raw_image = source.convert("RGB")
            inputs = self.caption_processor(raw_image, return_tensors="pt")
            out = self.caption_model.generate(**inputs, max_new_tokens=50)
            return self.caption_processor.decode(out[0], skip_special_tokens=True)
        except Exception as e:
            # Best-effort by design: callers get a placeholder caption.
            print(f"Error describing frame: {e}")
            return "Could not describe the image."

    def detect_objects(self, image_path: str, threshold: float = 0.9) -> dict[str, int]:
        """Detect objects in the frame.

        Args:
            image_path: Path of the image file to analyze.
            threshold: Minimum confidence score for a detection to be kept.

        Returns:
            A mapping of object label to the number of detections with that
            label. Empty when nothing is detected or when anything fails
            (the error is printed, not raised).
        """
        try:
            with Image.open(image_path) as source:
                image = source.convert("RGB")
            inputs = self.detection_processor(images=image, return_tensors="pt")
            # Inference only: skip autograd bookkeeping for the forward pass.
            with torch.no_grad():
                outputs = self.detection_model(**inputs)
            # PIL reports (width, height); reverse it to (height, width).
            target_sizes = torch.tensor([image.size[::-1]])
            results = self.detection_processor.post_process_object_detection(
                outputs, target_sizes=target_sizes, threshold=threshold
            )[0]
            id2label = self.detection_model.config.id2label
            return dict(Counter(id2label[label.item()] for label in results["labels"]))
        except Exception as e:
            # Best-effort by design; keep the failure value the same type
            # (a dict) as the success value.
            print(f"Error detecting objects: {e}")
            return {}
# Shared module-level VideoAnalyzer used by get_frame_infos; constructing it
# loads both pretrained models when this module is imported.
analyzer = VideoAnalyzer()
def get_frame_infos(filename: str) -> dict:
    """Extract information from a frame.

    Args:
        filename: Path of the image file to analyze.

    Returns:
        ``{"error": "File not found"}`` when ``filename`` is not an existing
        regular file; otherwise a dict with the keys ``"filename"``,
        ``"description"`` (caption text) and ``"objects"`` (detection result
        from the module-level ``analyzer``).
    """
    # isfile() rather than exists(): a directory path cannot be opened as an
    # image, so reject it here instead of running both models on it.
    if not os.path.isfile(filename):
        return {"error": "File not found"}

    description = analyzer.describe_frame(filename)
    objects = analyzer.detect_objects(filename)
    return {
        "filename": filename,
        "description": description,
        "objects": objects,
    }