owinymarvin committed
Commit 6c2ffca · 1 Parent(s): 0ddc9dd

latest changes

Files changed (2)
  1. app.py +179 -198
  2. requirements.txt +5 -3
app.py CHANGED
@@ -1,206 +1,187 @@
  import gradio as gr
  import torch
- from transformers import AutoImageProcessor, TimesformerForVideoClassification
  import cv2
- from PIL import Image
  import numpy as np
- import time
- from collections import deque
- import base64
- import io
-
- # --- Configuration ---
- # CHANGED: Using a public Facebook TimesFormer model fine-tuned on Kinetics
- HF_MODEL_REPO_ID = "facebook/timesformer-base-finetuned-k400"
-
- MODEL_INPUT_NUM_FRAMES = 8
- TARGET_IMAGE_HEIGHT = 224
- TARGET_IMAGE_WIDTH = 224
- RAW_RECORDING_DURATION_SECONDS = 10.0
- FRAMES_TO_SAMPLE_PER_CLIP = 20
- DELAY_BETWEEN_PREDICTIONS_SECONDS = 120.0 # 2 minutes for CPU, adjust for GPU
-
- # --- Load Model and Processor ---
- print(f"Loading model and processor from {HF_MODEL_REPO_ID}...")
- try:
-     processor = AutoImageProcessor.from_pretrained(HF_MODEL_REPO_ID)
-     model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID)
- except Exception as e:
-     print(f"Error loading model: {e}")
-     exit()
-
- model.eval()
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  model.to(device)
- print(f"Model loaded on {device}.")
- print(f"Model's class labels (Kinetics): {model.config.id2label}") # Print new labels
-
- # --- Global State Variables for Live Demo ---
- raw_frames_buffer = deque()
- current_clip_start_time = time.time()
- last_prediction_completion_time = time.time()
- app_state = "recording" # States: "recording", "predicting", "processing_delay"
-
- # --- Helper function to sample frames ---
- def sample_frames(frames_list, target_count):
-     if not frames_list:
-         return []
-     if len(frames_list) <= target_count:
-         return frames_list
-     indices = np.linspace(0, len(frames_list) - 1, target_count, dtype=int)
-     sampled = [frames_list[int(i)] for i in indices]
-     return sampled
-
- # --- Main processing function for Live Demo Stream ---
- def live_predict_stream(image_np_array):
-     global raw_frames_buffer, current_clip_start_time, last_prediction_completion_time, app_state
-
-     current_time = time.time()
-     pil_image = Image.fromarray(image_np_array)
-
-     if app_state == "recording":
-         raw_frames_buffer.append(pil_image)
-         elapsed_recording_time = current_time - current_clip_start_time

-         yield f"Recording: {elapsed_recording_time:.1f}/{RAW_RECORDING_DURATION_SECONDS}s. Raw frames: {len(raw_frames_buffer)}", "Buffering..."
-
-         if elapsed_recording_time >= RAW_RECORDING_DURATION_SECONDS:
-             # Transition to predicting state
-             app_state = "predicting"
-             yield "Preparing to predict...", "Processing..."
-             print("DEBUG: Transitioning to 'predicting' state.")
-
-     elif app_state == "predicting":
-         # Ensure this prediction block only runs once per cycle
-         if raw_frames_buffer: # Only proceed if there are frames to process
-             print("DEBUG: Starting prediction.")
-             try:
-                 sampled_raw_frames = sample_frames(list(raw_frames_buffer), FRAMES_TO_SAMPLE_PER_CLIP)
-                 frames_for_model = sample_frames(sampled_raw_frames, MODEL_INPUT_NUM_FRAMES)
-
-                 if len(frames_for_model) < MODEL_INPUT_NUM_FRAMES:
-                     yield "Error during frame sampling.", f"Error: Not enough frames ({len(frames_for_model)}/{MODEL_INPUT_NUM_FRAMES}). Resetting."
-                     print(f"ERROR: Insufficient frames for model input: {len(frames_for_model)}/{MODEL_INPUT_NUM_FRAMES}. Resetting state.")
-                     app_state = "recording" # Reset state to start a new recording
-                     raw_frames_buffer.clear()
-                     current_clip_start_time = time.time()
-                     last_prediction_completion_time = time.time()
-                     return # Exit this stream call to wait for next frame or reset
-
-                 processed_input = processor(images=frames_for_model, return_tensors="pt")
-                 pixel_values = processed_input.pixel_values.to(device)
-
-                 with torch.no_grad():
-                     outputs = model(pixel_values)
-                     logits = outputs.logits
-
-                 predicted_class_id = logits.argmax(-1).item()
-                 predicted_label = model.config.id2label.get(predicted_class_id, "Unknown")
-                 confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()
-
-                 prediction_result = f"Predicted: {predicted_label} (Confidence: {confidence:.2f})"
-                 status_message = "Prediction complete."
-                 print(f"DEBUG: Prediction Result: {prediction_result}")
-
-                 # Yield the prediction result immediately to ensure UI update
-                 yield status_message, prediction_result
-
-                 # Clear buffer and transition to delay AFTER yielding the prediction
-                 raw_frames_buffer.clear()
-                 last_prediction_completion_time = current_time
-                 app_state = "processing_delay"
-                 print("DEBUG: Transitioning to 'processing_delay' state.")
-
-             except Exception as e:
-                 error_message = f"Error during prediction: {e}"
-                 print(f"ERROR during prediction: {e}")
-                 # Yield error to UI
-                 yield "Prediction error.", error_message
-                 app_state = "processing_delay" # Still go to delay state to prevent constant errors
-                 raw_frames_buffer.clear() # Clear buffer to prevent re-processing same problematic frames
-
-     elif app_state == "processing_delay":
-         elapsed_delay = current_time - last_prediction_completion_time

-         if elapsed_delay < DELAY_BETWEEN_PREDICTIONS_SECONDS:
-             # Continue yielding the delay message and the last prediction result
-             yield f"Delaying next prediction: {int(elapsed_delay)}/{int(DELAY_BETWEEN_PREDICTIONS_SECONDS)}s", gr.NO_VALUE
-         else:
-             # Delay is over, reset for new recording cycle
-             app_state = "recording"
-             current_clip_start_time = current_time
-             print("DEBUG: Transitioning back to 'recording' state.")
-             yield "Starting new recording...", "Ready for new prediction."

-     pass
-
- def reset_app_state_manual():
-     global raw_frames_buffer, current_clip_start_time, last_prediction_completion_time, app_state
-     raw_frames_buffer.clear()
-     current_clip_start_time = time.time()
-     last_prediction_completion_time = time.time()
-     app_state = "recording"
-     print("DEBUG: Manual reset triggered.")
-     # Return initial values immediately upon reset
-     return "Ready to record...", "Ready for new prediction."
-
- # --- Gradio UI Layout ---
- with gr.Blocks() as demo:
-     gr.Markdown(
-         f"""
-         # TimesFormer Action Recognition - Using Facebook Kinetics Model
-         This Space hosts the `{HF_MODEL_REPO_ID}` model.
-         Live webcam demo with recording and prediction phases.
-         **NOTE: This model predicts general human actions (e.g., 'playing guitar', 'walking'), not crime events.**
-         """
-     )
-
-     with gr.Tab("Live Webcam Demo"):
-         gr.Markdown(
-             f"""
-             Continuously captures live webcam feed for **{RAW_RECORDING_DURATION_SECONDS} seconds**,
-             then makes a prediction. There is a **{DELAY_BETWEEN_PREDICTIONS_SECONDS/60:.0f} minute delay** afterwards.
-             """
-         )
-         with gr.Row():
-             with gr.Column():
-                 webcam_input = gr.Image(
-                     sources=["webcam"],
-                     streaming=True,
-                     label="Live Webcam Feed"
-                 )
-                 status_output = gr.Textbox(label="Current Status", value="Initializing...")
-                 reset_button = gr.Button("Reset / Start New Cycle")
-             with gr.Column():
-                 prediction_output = gr.Textbox(label="Prediction Result", value="Waiting...")
-
-         webcam_input.stream(
-             live_predict_stream,
-             inputs=[webcam_input],
-             outputs=[status_output, prediction_output]
-         )
-
-         reset_button.click(
-             reset_app_state_manual,
-             inputs=[],
-             outputs=[status_output, prediction_output]
-         )
-
-     with gr.Tab("API Endpoint for External Clients"):
-         gr.Markdown(
-             """
-             Use this API endpoint to send base64-encoded frames for prediction.
-             (Currently uses the Kinetics model).
-             """
-         )
-         gr.Interface(
-             fn=lambda frames_list: "API endpoint is active for programmatic calls. See documentation in app.py.",
-             inputs=gr.Json(label="List of Base64-encoded image strings"),
-             outputs=gr.Textbox(label="API Response"),
-             live=False,
-             allow_flagging="never"
-         )
-
-
- if __name__ == "__main__":
-     demo.launch()
  import gradio as gr
  import torch
  import cv2
  import numpy as np
+ import os
+ import json
+ from PIL import Image
+ from torchvision import transforms
+ from huggingface_hub import hf_hub_download
+ import tempfile # For temporary file handling
+
+ # --- 1. Define Model Architecture (Copy from small_video_classifier.py) ---
+ # This is crucial because we need the model class definition to load weights.
+ class SmallVideoClassifier(torch.nn.Module):
+     def __init__(self, num_classes=2, num_frames=8):
+         super(SmallVideoClassifier, self).__init__()
+         from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
+         # Load weights only if they are available (they should be for IMAGENET1K_V1)
+         # Use a check to prevent error if weights are not found in specific environments
+         try:
+             weights = MobileNet_V3_Small_Weights.IMAGENET1K_V1
+         except Exception:
+             print("Warning: MobileNet_V3_Small_Weights.IMAGENET1K_V1 not found, initializing without pre-trained weights.")
+             weights = None # Or provide specific default for your use case
+
+         self.feature_extractor = mobilenet_v3_small(weights=weights)
+         self.feature_extractor.classifier = torch.nn.Identity()
+         self.num_spatial_features = 576 # MobileNetV3-Small's final feature map size
+         self.temporal_aggregator = torch.nn.AdaptiveAvgPool1d(1)
+         self.classifier = torch.nn.Sequential(
+             torch.nn.Linear(self.num_spatial_features, 512),
+             torch.nn.ReLU(),
+             torch.nn.Dropout(0.2),
+             torch.nn.Linear(512, num_classes)
+         )
+
+     def forward(self, pixel_values):
+         batch_size, num_frames, channels, height, width = pixel_values.shape
+         x = pixel_values.view(batch_size * num_frames, channels, height, width)
+         spatial_features = self.feature_extractor(x)
+         spatial_features = spatial_features.view(batch_size, num_frames, self.num_spatial_features)
+         temporal_features = self.temporal_aggregator(spatial_features.permute(0, 2, 1)).squeeze(-1)
+         logits = self.classifier(temporal_features)
+         return logits
+
+ # --- 2. Configuration and Model Loading ---
+ HF_USERNAME = "owinymarvin"
+ NEW_MODEL_REPO_ID_SHORT = "timesformer-violence-detector"
+ NEW_MODEL_REPO_ID = f"{HF_USERNAME}/{NEW_MODEL_REPO_ID_SHORT}"
+
+ # Download config.json to get model parameters
+ print(f"Downloading config.json from {NEW_MODEL_REPO_ID}...")
+ config_path = hf_hub_download(repo_id=NEW_MODEL_REPO_ID, filename="config.json")
+ with open(config_path, 'r') as f:
+     model_config = json.load(f)
+
+ NUM_FRAMES = model_config.get('num_frames', 8)
+ IMAGE_SIZE = tuple(model_config.get('image_size', [224, 224]))
+ NUM_CLASSES = model_config.get('num_classes', 2)
+
+ # Define class labels (adjust if your dataset had different labels/order)
+ CLASS_LABELS = ["Non-violence", "Violence"]
+ if NUM_CLASSES != len(CLASS_LABELS):
+     print(f"Warning: NUM_CLASSES in config ({NUM_CLASSES}) does not match hardcoded CLASS_LABELS length ({len(CLASS_LABELS)}). Adjust CLASS_LABELS if needed.")
+
+
+ # Initialize the model
+ device = torch.device("cpu") # Explicitly use CPU as requested
+ print(f"Using device: {device}")
+
+ model = SmallVideoClassifier(num_classes=NUM_CLASSES, num_frames=NUM_FRAMES)
+
+ # Download model weights
+ print(f"Downloading model weights from {NEW_MODEL_REPO_ID}...")
+ model_weights_path = hf_hub_download(repo_id=NEW_MODEL_REPO_ID, filename="small_violence_classifier.pth")
+ model.load_state_dict(torch.load(model_weights_path, map_location=device))
  model.to(device)
+ model.eval() # Set model to evaluation mode
+
+ print(f"Model loaded successfully with {NUM_FRAMES} frames and image size {IMAGE_SIZE}.")
+
+ # --- 3. Define Preprocessing Transform ---
+ transform = transforms.Compose([
+     transforms.Resize(IMAGE_SIZE),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ ])
+
+ # --- 4. Gradio Inference Function ---
+ def predict_video(video_path):
+     if video_path is None:
+         return None # Or raise an error, or return a placeholder video
+
+     cap = cv2.VideoCapture(video_path)
+
+     if not cap.isOpened():
+         print(f"Error: Could not open video file {video_path}.")
+         raise ValueError(f"Could not open video file {video_path}. Please ensure it's a valid video format.")
+
+     # Get video properties
+     frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+     fps = cap.get(cv2.CAP_PROP_FPS)
+     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+     # Create a temporary output video file
+     # Use tempfile to ensure proper cleanup on Hugging Face Spaces
+     temp_output_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
+     output_video_path = temp_output_file.name
+     temp_output_file.close() # Close the file handle as cv2.VideoWriter needs to open it
+
+     # Define the codec and create VideoWriter object
+     # For MP4, 'mp4v' is generally compatible.
+     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+     out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))
+
+     print(f"Processing video: {video_path}")
+     print(f"Total frames: {total_frames}, FPS: {fps}")
+     print(f"Output video will be saved to: {output_video_path}")
+
+     frame_buffer = [] # To store NUM_FRAMES for each prediction batch
+     current_prediction_label = "Processing..." # Initial label
+
+     frame_idx = 0
+     while True:
+         ret, frame = cap.read()
+         if not ret:
+             break # End of video
+
+         frame_idx += 1

+         # Convert frame from BGR (OpenCV) to RGB (PIL/PyTorch)
+         frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+         pil_image = Image.fromarray(frame_rgb)

+         # Apply transformations and add to buffer
+         processed_frame = transform(pil_image) # shape: (C, H, W)
+         frame_buffer.append(processed_frame)
+
+         # Perform prediction when the buffer is full
+         if len(frame_buffer) == NUM_FRAMES:
+             # Stack the buffered frames and add a batch dimension
+             # Resulting shape: (1, NUM_FRAMES, C, H, W)
+             input_tensor = torch.stack(frame_buffer, dim=0).unsqueeze(0).to(device)
+
+             with torch.no_grad():
+                 outputs = model(input_tensor)
+                 probabilities = torch.softmax(outputs, dim=1)
+                 predicted_class_idx = torch.argmax(probabilities, dim=1).item()
+                 current_prediction_label = f"Prediction: {CLASS_LABELS[predicted_class_idx]} (Prob: {probabilities[0, predicted_class_idx]:.2f})"
+
+             # Reset buffer for the next non-overlapping window
+             frame_buffer = []
+             # Or, if you want sliding window (more continuous output but higher compute):
+             # frame_buffer = frame_buffer[1:] # e.g., slide by 1 frame
+
+         # Draw prediction text on the current frame
+         # The prediction will lag by NUM_FRAMES, as it's based on the previous batch.
+         # We display the last known prediction.
+         cv2.putText(frame, current_prediction_label, (10, 30),
+                     cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2, cv2.LINE_AA)
+
+         # Write the processed frame to the output video
+         out.write(frame)
+
+     # Release resources
+     cap.release()
+     out.release()
+     print(f"Video processing complete. Output saved to: {output_video_path}")

+     return output_video_path # Gradio expects the path to the output video
+
+ # --- 5. Gradio Interface Setup ---
+ iface = gr.Interface(
+     fn=predict_video,
+     inputs=gr.Video(label="Upload Video for Violence Detection (MP4 recommended)", type="filepath"),
+     outputs=gr.Video(label="Processed Video with Predictions"),
+     title="Real-time Violence Detection with SmallVideoClassifier",
+     description="Upload a video, and the model will analyze it for violence, displaying the predicted class and confidence on each frame.",
+     allow_flagging="never", # Disable flagging on Hugging Face Spaces
+     examples=[
+         # You can provide example video URLs or paths if you have them publicly available
+         # Example: "https://huggingface.co/datasets/gradio/test-files/resolve/main/video.mp4"
+     ]
+ )
+
+ iface.launch()
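
For reference (not part of the commit), the following is a standalone sketch of the shape bookkeeping the new SmallVideoClassifier.forward() performs: time is folded into the batch dimension, each frame is reduced to a 576-dimensional MobileNetV3-Small feature vector, the per-frame features are averaged over time, and the pooled vector goes through the classifier head. Random weights are used so it runs offline, and the plain mean over the time axis has the same effect as the AdaptiveAvgPool1d(1) + squeeze in the committed class; names like backbone and head are local to this sketch.

import torch
from torchvision.models import mobilenet_v3_small

NUM_FRAMES, NUM_CLASSES = 8, 2                  # same defaults as config.json above

backbone = mobilenet_v3_small(weights=None)     # random weights, so no download is needed
backbone.classifier = torch.nn.Identity()       # expose the 576-d pooled features
head = torch.nn.Linear(576, NUM_CLASSES)        # stand-in for the two-layer classifier above

clip = torch.randn(1, NUM_FRAMES, 3, 224, 224)  # (B, T, C, H, W), as stacked from frame_buffer
b, t, c, h, w = clip.shape
frames = clip.view(b * t, c, h, w)              # fold time into the batch dimension
feats = backbone(frames).view(b, t, -1)         # (B*T, 576) -> (B, T, 576)
pooled = feats.mean(dim=1)                      # temporal average, equivalent to AdaptiveAvgPool1d(1)
logits = head(pooled)                           # (B, NUM_CLASSES)
print(feats.shape, pooled.shape, logits.shape)  # (1, 8, 576), (1, 576), (1, 2)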
requirements.txt CHANGED
@@ -1,5 +1,7 @@
  torch
- transformers
- opencv-python-headless # Use headless for server environments without display
  Pillow
- gradio

  torch
+ torchvision
+ opencv-python-headless # Use headless for server environments to avoid GUI dependencies
+ gradio
+ huggingface_hub
  Pillow
+ numpy
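
With these dependencies in place, the deployed Interface can also be driven programmatically. A hedged sketch, assuming the Space is public and a recent gradio_client is installed (it is available via pip and ships as a dependency of gradio); the Space ID below is a placeholder, and older client versions accept a plain file path instead of handle_file:

from gradio_client import Client, handle_file

SPACE_ID = "owinymarvin/your-space-name"  # placeholder; replace with the actual Space ID

client = Client(SPACE_ID)
result = client.predict(
    handle_file("sample_clip.mp4"),  # local video to analyse
    api_name="/predict",             # default endpoint name for a single gr.Interface
)
print(result)  # the processed video returned by the Space (usually a local file path)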