import cv2
import os
from PIL import Image
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration, DetrImageProcessor, DetrForObjectDetection
from collections import Counter


class VideoAnalyzer:
    def __init__(self):
        """Initialize the models."""
        print("Loading Image Captioning model...")
        self.caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
        self.caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")

        print("Loading Object Detection model...")
        self.detection_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
        self.detection_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

    def describe_frame(self, image_path: str) -> str:
        """Generate a text description of the frame."""
        try:
            raw_image = Image.open(image_path).convert("RGB")
            inputs = self.caption_processor(raw_image, return_tensors="pt")
            out = self.caption_model.generate(**inputs, max_new_tokens=50)
            caption = self.caption_processor.decode(out[0], skip_special_tokens=True)
            return caption
        except Exception as e:
            print(f"Error describing frame: {e}")
            return "Could not describe the image."

    def detect_objects(self, image_path: str, threshold: float = 0.9) -> dict[str, int]:
        """Detect objects in the frame and return a {label: count} mapping."""
        try:
            image = Image.open(image_path).convert("RGB")
            inputs = self.detection_processor(images=image, return_tensors="pt")
            # Inference only, so skip gradient tracking.
            with torch.no_grad():
                outputs = self.detection_model(**inputs)
            # post_process_object_detection expects (height, width); PIL's .size is (width, height).
            target_sizes = torch.tensor([image.size[::-1]])
            results = self.detection_processor.post_process_object_detection(
                outputs, target_sizes=target_sizes, threshold=threshold
            )[0]

            # Count how many instances of each detected label appear in the frame.
            detected_objects = [
                self.detection_model.config.id2label[label.item()]
                for label in results["labels"]
            ]

            return dict(Counter(detected_objects))
        except Exception as e:
            print(f"Error detecting objects: {e}")
            return {}

# Initialize the VideoAnalyzer (loads both models once at import time)
analyzer = VideoAnalyzer()

def get_frame_infos(filename: str) -> dict:
    """Extract information from a frame."""
    if not os.path.exists(filename):
        return {"error": "File not found"}

    description = analyzer.describe_frame(filename)
    objects = analyzer.detect_objects(filename)

    return {
        "filename": filename,
        "description": description,
        "objects": objects
    }
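

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original pipeline:
    # "sample_video.mp4" and "frame_0000.jpg" are hypothetical paths; swap in real files.
    # Grab the first frame of a video with OpenCV, save it to disk, and analyze it.
    cap = cv2.VideoCapture("sample_video.mp4")
    ok, frame = cap.read()
    cap.release()
    if ok:
        frame_path = "frame_0000.jpg"
        cv2.imwrite(frame_path, frame)  # get_frame_infos expects an image file on disk
        print(get_frame_infos(frame_path))
    else:
        print("Could not read a frame from sample_video.mp4")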