RCaz committed
Commit bd1111c · 1 Parent(s): 3255474

added video analyser modal

deploy_serverless_modal_image_processing.py ADDED
@@ -0,0 +1,123 @@
+ import io
+
+ import modal
+
+ # --- Modal App Setup ---
+ # This section defines the environment and models for our serverless functions.
+ # The container image will have all necessary libraries pre-installed.
+ modal_app = modal.App("video-analysis-app")
+ image = modal.Image.debian_slim().pip_install(
+     "torch", "transformers", "Pillow"
+ )
+
+ # Heavy libraries are imported inside the image context, so they only need to
+ # be installed in the container, not on the local machine running `modal run`.
+ with image.imports():
+     import torch
+     from PIL import Image
+
+ # Define a class that loads the models only once, when the container starts.
+ # This is a Modal best practice that avoids slow cold starts.
+ # We request a GPU to accelerate the model inference.
+ @modal_app.cls(gpu="any", image=image)
+ class VideoAnalyzer:
+     @modal.enter()
+     def load_models(self):
+         """
+         Runs once when the container starts. It loads the models and
+         processors into memory, so they are ready for immediate use
+         by the methods below.
+         """
+         from transformers import BlipProcessor, BlipForConditionalGeneration, DetrImageProcessor, DetrForObjectDetection
+
+         print("Loading Image Captioning model...")
+         # Model for describing the frame (image captioning)
+         self.caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
+         self.caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to("cuda")
+         print("Image Captioning model loaded.")
+
+         print("Loading Object Detection model...")
+         # Model for detecting objects
+         self.detection_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
+         self.detection_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50").to("cuda")
+         print("Object Detection model loaded.")
+
+     @modal.method()
+     def describe_frame(self, image_bytes: bytes) -> str:
+         """Take image bytes as input and return a generated text description."""
+         try:
+             # Open the image from the raw bytes
+             raw_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
+
+             # Process the image and generate a caption
+             inputs = self.caption_processor(raw_image, return_tensors="pt").to("cuda")
+             out = self.caption_model.generate(**inputs, max_new_tokens=50)
+
+             # Decode the caption and return it
+             caption = self.caption_processor.decode(out[0], skip_special_tokens=True)
+             return caption
+         except Exception as e:
+             print(f"Error describing frame: {e}")
+             return "Could not describe the image."
+
+     @modal.method()
+     def detect_objects(self, image_bytes: bytes, threshold: float = 0.9) -> list[str]:
+         """
+         Take image bytes as input and return a list of detected object labels.
+         Only objects with a confidence score above the threshold are returned.
+         """
+         try:
+             image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
+
+             # Process the image for the object detection model
+             inputs = self.detection_processor(images=image, return_tensors="pt").to("cuda")
+             with torch.no_grad():
+                 outputs = self.detection_model(**inputs)
+
+             # Post-process the results to get labels and scores
+             target_sizes = torch.tensor([image.size[::-1]])
+             results = self.detection_processor.post_process_object_detection(
+                 outputs, target_sizes=target_sizes, threshold=threshold
+             )[0]
+
+             # Extract the labels for confident detections
+             detected_objects = []
+             for label in results["labels"]:
+                 object_name = self.detection_model.config.id2label[label.item()]
+                 detected_objects.append(object_name)
+
+             # Return a list of unique object names
+             return list(set(detected_objects))
+         except Exception as e:
+             print(f"Error detecting objects: {e}")
+             return []
+
+ # --- Local Runner ---
+ # This part of the script runs on your local machine to call the serverless functions.
+ # It reads a local image file and sends it to the Modal functions for processing.
+ @modal_app.local_entrypoint()
+ def main(image_path: str):
+     """
+     A local entrypoint to test the Modal functions.
+
+     Usage:
+         modal run this_script_name.py --image-path /path/to/your/image.jpg
+     """
+     try:
+         with open(image_path, "rb") as f:
+             img_bytes = f.read()
+     except FileNotFoundError:
+         print(f"Error: Image file not found at '{image_path}'")
+         return
+
+     print("--- Calling Modal Functions ---")
+
+     # Instantiate the class; model loading runs on the remote container
+     # when it starts, via the @modal.enter() hook.
+     analyzer = VideoAnalyzer()
+
+     # Call the remote functions with the image data
+     description = analyzer.describe_frame.remote(img_bytes)
+     objects = analyzer.detect_objects.remote(img_bytes)
+
+     print(f"\n📸 Analysis for: {image_path}")
+     print("-----------------------------------")
+     print(f"📝 Description: {description}")
+     print(f"📦 Detected Objects: {objects}")
+     print("-----------------------------------")
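
Usage note (a sketch, not part of the commit): with the Modal CLI installed and authenticated, the local entrypoint above can be exercised as its docstring describes; the sample image path below is a placeholder.

    pip install modal
    modal setup
    modal run deploy_serverless_modal_image_processing.py --image-path ./sample.jpg
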
load_vision_model_locally.py ADDED
@@ -0,0 +1,69 @@
+ import os
+
+ import torch
+ from PIL import Image
+ from transformers import BlipProcessor, BlipForConditionalGeneration, DetrImageProcessor, DetrForObjectDetection
+
+
+ class VideoAnalyzer:
+     def __init__(self):
+         """Initialize the models."""
+         print("Loading Image Captioning model...")
+         self.caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
+         self.caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
+
+         print("Loading Object Detection model...")
+         self.detection_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
+         self.detection_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
+
+     def describe_frame(self, image_path: str) -> str:
+         """Generate a text description of the frame."""
+         try:
+             raw_image = Image.open(image_path).convert("RGB")
+             inputs = self.caption_processor(raw_image, return_tensors="pt")
+             out = self.caption_model.generate(**inputs, max_new_tokens=50)
+             caption = self.caption_processor.decode(out[0], skip_special_tokens=True)
+             return caption
+         except Exception as e:
+             print(f"Error describing frame: {e}")
+             return "Could not describe the image."
+
+     def detect_objects(self, image_path: str, threshold: float = 0.9) -> list[str]:
+         """Detect objects in the frame, keeping detections above the confidence threshold."""
+         try:
+             image = Image.open(image_path).convert("RGB")
+             inputs = self.detection_processor(images=image, return_tensors="pt")
+             with torch.no_grad():
+                 outputs = self.detection_model(**inputs)
+             target_sizes = torch.tensor([image.size[::-1]])
+             results = self.detection_processor.post_process_object_detection(
+                 outputs, target_sizes=target_sizes, threshold=threshold
+             )[0]
+
+             detected_objects = []
+             for label in results["labels"]:
+                 object_name = self.detection_model.config.id2label[label.item()]
+                 detected_objects.append(object_name)
+
+             return list(set(detected_objects))
+         except Exception as e:
+             print(f"Error detecting objects: {e}")
+             return []
+
+ # Initialize the VideoAnalyzer once at import time so callers share the loaded models
+ analyzer = VideoAnalyzer()
+
+ def get_frame_infos(filename: str) -> dict:
+     """Extract a description and the detected objects from a frame on disk."""
+     if not os.path.exists(filename):
+         return {"error": "File not found"}
+
+     description = analyzer.describe_frame(filename)
+     objects = analyzer.detect_objects(filename)
+
+     return {
+         "filename": filename,
+         "description": description,
+         "objects": objects
+     }
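
Usage note (an illustrative sketch; the frame path is a placeholder): importing the module loads both models once, after which get_frame_infos can be called per frame.

    from load_vision_model_locally import get_frame_infos

    info = get_frame_infos("/tmp/video/frames/frame_0000.jpg")
    print(info["description"])
    print(info["objects"])
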
utils.py CHANGED
@@ -124,3 +124,56 @@ def extract_keyframes(video_path, diff_threshold=0.4):
  cap.release()
  print(f"Extracted {saved_id} key frames.")
  return "success"
+
+
+ def extract_nfps_frames(video_path, nfps=5, diff_threshold=0.4):
+     """Extract one frame every `nfps` seconds from a video, keeping only
+     frames that differ significantly from the previously saved frame.
+
+     Args:
+         video_path (str): Path to the input video file.
+         nfps (int): Number of seconds between sampled frames.
+         diff_threshold (float): Minimum difference for a frame to be kept.
+     """
+     cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened():
+         print("Failed to read video.")
+         return []
+
+     output_path = '/tmp/video/frames'
+     os.makedirs(output_path, exist_ok=True)
+
+     fps = cap.get(cv2.CAP_PROP_FPS)
+     frame_interval = int(fps) * nfps  # Capture one frame every nfps seconds
+
+     frame_id = 0
+     saved_id = 0
+     all_frames_data = []  # Entries that will constitute the RAG document
+     success, prev_frame = cap.read()
+
+     while True:
+         success, frame = cap.read()
+         if not success:
+             break
+
+         if frame_id % frame_interval == 0 and is_significantly_different(prev_frame, frame, threshold=diff_threshold):
+             filename = os.path.join(output_path, f"frame_{saved_id:04d}.jpg")
+             cv2.imwrite(filename, frame)
+             prev_frame = frame
+             saved_id += 1
+
+             # Append to the list that will constitute the RAG document
+             frame_data = get_frame_infos(filename)
+             all_frames_data.append(frame_data)
+         frame_id += 1
+
+     cap.release()
+     print(f"Extracted {saved_id} frames (one every {nfps} seconds).")
+     return all_frames_data
+
+
+ def get_frame_infos(filename: str) -> dict:
+     # Reuse the module-level analyzer so the models are loaded only once,
+     # rather than being re-instantiated for every frame.
+     from load_vision_model_locally import analyzer
+
+     description = analyzer.describe_frame(filename)
+     detection = analyzer.detect_objects(filename)
+
+     return {
+         "filename": filename,
+         "description": description,
+         "objects": detection,
+     }
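
Usage note (an illustrative sketch; the video path is a placeholder, and is_significantly_different is defined earlier in utils.py): the returned list of per-frame dicts can be fed directly into a RAG document store.

    frames_data = extract_nfps_frames("/tmp/video/input.mp4", nfps=5, diff_threshold=0.4)
    for entry in frames_data:
        print(entry["filename"], "->", entry["description"], entry["objects"])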