MCP_Track3_Discover / load_vision_model_locally.py
RCaz's picture
added transcription
7e4124b
import cv2
import os
import io
from PIL import Image
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration, DetrImageProcessor, DetrForObjectDetection
from collections import Counter
class VideoAnalyzer:
    """Caption image frames and count the objects detected in them.

    ``__init__`` loads two pretrained model/processor pairs and stores them as:
        caption_processor, caption_model: loaded from
            ``Salesforce/blip-image-captioning-large``.
        detection_processor, detection_model: loaded from
            ``facebook/detr-resnet-50``.
    """

    def __init__(self):
        """Load the captioning and object-detection models."""
        print("Loading Image Captioning model...")
        self.caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
        self.caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
        print("Loading Object Detection model...")
        self.detection_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
        self.detection_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

    @staticmethod
    def _load_rgb(image_path: str) -> "Image.Image":
        """Return an RGB copy of the image at *image_path*, closing the file."""
        # convert() produces an independent in-memory image, so the underlying
        # file handle can be released as soon as the ``with`` block exits.
        with Image.open(image_path) as source:
            return source.convert("RGB")

    def describe_frame(self, image_path: str) -> str:
        """Generate a text description of the frame.

        Args:
            image_path: Path to the image file to caption.

        Returns:
            The decoded caption, or ``"Could not describe the image."`` if
            loading the image or running the model raises.
        """
        try:
            raw_image = self._load_rgb(image_path)
            inputs = self.caption_processor(raw_image, return_tensors="pt")
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                out = self.caption_model.generate(**inputs, max_new_tokens=50)
            return self.caption_processor.decode(out[0], skip_special_tokens=True)
        except Exception as e:
            # Best-effort: report the problem and fall back to a fixed string.
            print(f"Error describing frame: {e}")
            return "Could not describe the image."

    def detect_objects(self, image_path: str, threshold: float = 0.9) -> dict[str, int]:
        """Detect objects in the frame and count them per label.

        Args:
            image_path: Path to the image file to analyse.
            threshold: Score threshold passed to the processor's
                ``post_process_object_detection``.

        Returns:
            A dict mapping each detected label name to the number of
            detections carrying that label. An empty dict is returned if
            loading the image or running the model raises, so the return
            type is the same on success and on failure.
        """
        try:
            image = self._load_rgb(image_path)
            inputs = self.detection_processor(images=image, return_tensors="pt")
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                outputs = self.detection_model(**inputs)
            # PIL's ``size`` is (width, height); reversed to (height, width)
            # for the post-processing step.
            target_sizes = torch.tensor([image.size[::-1]])
            results = self.detection_processor.post_process_object_detection(
                outputs, target_sizes=target_sizes, threshold=threshold
            )[0]
            id2label = self.detection_model.config.id2label
            label_counts = Counter(id2label[label.item()] for label in results["labels"])
            return dict(label_counts)
        except Exception as e:
            # Best-effort: report the problem and return an empty count table.
            print(f"Error detecting objects: {e}")
            return {}
# Module-level VideoAnalyzer instance used by get_frame_infos below.
# Constructing it loads both the captioning and the object-detection models,
# and this statement runs at module top level.
analyzer = VideoAnalyzer()
def get_frame_infos(filename: str) -> dict:
    """Describe a frame image and count the objects detected in it.

    Args:
        filename: Path to the frame image on disk.

    Returns:
        ``{"error": "File not found"}`` when *filename* is not an existing
        regular file (missing paths and directories alike). Otherwise a dict
        with the keys ``"filename"`` (the path as given), ``"description"``
        (the result of ``analyzer.describe_frame``) and ``"objects"`` (the
        result of ``analyzer.detect_objects``).
    """
    # isfile() rather than exists(): a directory exists on disk but is not an
    # image, so it is rejected here instead of being handed to both models.
    if not os.path.isfile(filename):
        return {"error": "File not found"}

    return {
        "filename": filename,
        "description": analyzer.describe_frame(filename),
        "objects": analyzer.detect_objects(filename),
    }