owinymarvin committed
Commit 89b0c64 · Parent: 4ad907b

latest changes

Files changed (3):
  1. app copy.py +0 -101
  2. app.py +88 -97
  3. good copy.py +175 -203
app copy.py DELETED
@@ -1,101 +0,0 @@
- import gradio as gr
- import torch
- from transformers import AutoImageProcessor, TimesformerForVideoClassification
- import cv2
- from PIL import Image
- import numpy as np
- import time
- from collections import deque
-
- # --- Configuration ---
- # Your Hugging Face model repository ID
- HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"
-
- # These must match the values used during your training
- NUM_FRAMES = 8
- TARGET_IMAGE_HEIGHT = 224
- TARGET_IMAGE_WIDTH = 224
-
- # --- Load Model and Processor ---
- print(f"Loading model and image processor from {HF_MODEL_REPO_ID}...")
- try:
-     processor = AutoImageProcessor.from_pretrained(HF_MODEL_REPO_ID)
-     model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID)
- except Exception as e:
-     print(f"Error loading model from Hugging Face Hub: {e}")
-     # Handle error - exit or raise exception for Space to fail gracefully
-     exit()
-
- model.eval() # Set model to evaluation mode
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- model.to(device)
- print(f"Model loaded successfully on {device}.")
- print(f"Model's class labels: {model.config.id2label}")
-
- # Initialize a global buffer for frames for the session
- # Use a deque for efficient appending/popping from both ends
- frame_buffer = deque(maxlen=NUM_FRAMES)
- last_inference_time = time.time()
- inference_interval = 1.0 # Predict every 1 second (1.0 / INFERENCE_FPS)
- current_prediction_text = "Buffering frames..." # Initialize global text
-
- def predict_video_frame(image_np_array):
-     global frame_buffer, last_inference_time, current_prediction_text
-
-     # Gradio sends frames as numpy arrays (RGB)
-     # The image_processor will handle the resizing to TARGET_IMAGE_HEIGHT x TARGET_IMAGE_WIDTH
-     pil_image = Image.fromarray(image_np_array)
-     frame_buffer.append(pil_image)
-
-     current_time = time.time()
-
-     # Only perform inference if we have enough frames and it's time for a prediction
-     if len(frame_buffer) == NUM_FRAMES and (current_time - last_inference_time) >= inference_interval:
-         last_inference_time = current_time
-
-         # Preprocess the frames. processor expects a list of PIL Images or numpy arrays
-         # It will handle resizing and normalization based on its config
-         processed_input = processor(images=list(frame_buffer), return_tensors="pt")
-         pixel_values = processed_input.pixel_values.to(device)
-
-         with torch.no_grad():
-             outputs = model(pixel_values)
-             logits = outputs.logits
-
-         predicted_class_id = logits.argmax(-1).item()
-         predicted_label = model.config.id2label[predicted_class_id]
-         confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()
-
-         current_prediction_text = f"Predicted: {predicted_label} ({confidence:.2f})"
-         print(current_prediction_text) # Print to Space logs
-
-     # Return the current prediction text for display in the UI
-     # Gradio's streaming will update this textbox asynchronously
-     return current_prediction_text
-
- # --- Gradio Interface ---
- # Create a streaming input for webcam
- webcam_input = gr.Image(
-     sources=["webcam"], # Allows webcam input
-     streaming=True, # Enables continuous streaming of frames
-     # REMOVED: shape=(TARGET_IMAGE_WIDTH, TARGET_IMAGE_HEIGHT), # This was causing the TypeError
-     label="Live Webcam Feed"
- )
-
- # Output text box for predictions
- prediction_output = gr.Textbox(label="Real-time Prediction", value="Buffering frames...")
-
-
- # Define the Gradio Interface
- demo = gr.Interface(
-     fn=predict_video_frame,
-     inputs=webcam_input,
-     outputs=prediction_output,
-     live=True, # Enable live updates
-     allow_flagging="never", # Disable flagging on public demo
-     title="TimesFormer Crime Detection Live Demo",
-     description=f"This demo uses a finetuned TimesFormer model ({HF_MODEL_REPO_ID}) to predict crime actions from a live webcam feed. The model processes {NUM_FRAMES} frames at a time and makes a prediction every {inference_interval} seconds. Please allow webcam access.",
- )
-
- if __name__ == "__main__":
-     demo.launch()
 
app.py CHANGED
@@ -7,10 +7,9 @@ import json
  from PIL import Image
  from torchvision import transforms
  from huggingface_hub import hf_hub_download
- import tempfile # For temporary file handling
+ import time # For potential sleep to control frame rate if needed
 
  # --- 1. Define Model Architecture (Copy from small_video_classifier.py) ---
- # This is crucial because we need the model class definition to load weights.
  class SmallVideoClassifier(torch.nn.Module):
      def __init__(self, num_classes=2, num_frames=8):
          super(SmallVideoClassifier, self).__init__()
@@ -59,7 +58,7 @@ CLASS_LABELS = ["Non-violence", "Violence"]
  if NUM_CLASSES != len(CLASS_LABELS):
      print(f"Warning: NUM_CLASSES in config ({NUM_CLASSES}) does not match hardcoded CLASS_LABELS length ({len(CLASS_LABELS)}). Adjust CLASS_LABELS if needed.")
 
- device = torch.device("cpu")
+ device = torch.device("cpu") # Explicitly use CPU
  print(f"Using device: {device}")
 
  model = SmallVideoClassifier(num_classes=NUM_CLASSES, num_frames=NUM_FRAMES)
@@ -79,106 +78,98 @@ transform = transforms.Compose([
      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
  ])
 
- # --- 4. Gradio Inference Function ---
- def predict_video(video_path):
-     if video_path is None:
-         return None
-
-     cap = cv2.VideoCapture(video_path)
-
-     if not cap.isOpened():
-         print(f"Error: Could not open video file {video_path}.")
-         raise ValueError(f"Could not open video file {video_path}. Please ensure it's a valid video format.")
-
-     frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-     frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-     fps = cap.get(cv2.CAP_PROP_FPS)
-     # Ensure FPS is not zero to avoid division by zero errors, default to 25 if needed
-     if fps <= 0:
-         fps = 25.0
-         print(f"Warning: Original video FPS was 0 or less, defaulting to {fps}.")
-
-     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-
-     temp_output_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
-     output_video_path = temp_output_file.name
-     temp_output_file.close()
-
-     # --- CHANGED: Use XVID codec for better browser compatibility ---
-     # This might prevent Gradio's internal re-encoding.
-     fourcc = cv2.VideoWriter_fourcc(*'XVID')
-     out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))
-
-     print(f"Processing video: {video_path}")
-     print(f"Total frames: {total_frames}, FPS: {fps}")
-     print(f"Output video will be saved to: {output_video_path}")
-
-     frame_buffer = []
-     current_prediction_label = "Processing..."
-
-     frame_idx = 0
-     while True:
-         ret, frame = cap.read()
-         if not ret:
-             break
-
-         frame_idx += 1
-
-         frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-         pil_image = Image.fromarray(frame_rgb)
-
-         processed_frame = transform(pil_image)
-         frame_buffer.append(processed_frame)
-
-         if len(frame_buffer) == NUM_FRAMES:
-             input_tensor = torch.stack(frame_buffer, dim=0).unsqueeze(0).to(device)
-
-             with torch.no_grad():
-                 outputs = model(input_tensor)
-                 probabilities = torch.softmax(outputs, dim=1)
-                 predicted_class_idx = torch.argmax(probabilities, dim=1).item()
-                 current_prediction_label = f"Prediction: {CLASS_LABELS[predicted_class_idx]} (Prob: {probabilities[0, predicted_class_idx]:.2f})"
-
-             frame_buffer = []
-             # If you want a sliding window, you would do something like:
-             # frame_buffer = frame_buffer[int(NUM_FRAMES * 0.5):] # Slide by half the window size
-
-         # Draw prediction text on the current frame
-         # Ensure text color is clearly visible (e.g., white or bright green)
-         # Add a black outline for better readability
-         text_color = (0, 255, 0) # Green (BGR format for OpenCV)
-         text_outline_color = (0, 0, 0) # Black
-         font_scale = 1.0 # Increased font size
-         font_thickness = 2
-
-         # Draw outline first for better readability
-         cv2.putText(frame, current_prediction_label, (10, 40), # Slightly lower position
-                     cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_outline_color, font_thickness + 2, cv2.LINE_AA)
-         # Draw actual text
-         cv2.putText(frame, current_prediction_label, (10, 40),
-                     cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_color, font_thickness, cv2.LINE_AA)
-
-         out.write(frame)
-
-     cap.release()
-     out.release()
-     print(f"Video processing complete. Output saved to: {output_video_path}")
-
-     return output_video_path
+ # --- 4. Gradio Live Inference Function (Generator) ---
+ # This function will receive individual frames from the webcam
+ def predict_live_frames(input_frame):
+     global frame_buffer, current_prediction_label, current_probabilities # Use global to maintain state across calls
+
+     if input_frame is None:
+         # If no frame is received (e.g., webcam not active), yield a black frame or handle gracefully
+         dummy_frame = np.zeros((200, 400, 3), dtype=np.uint8)
+         cv2.putText(dummy_frame, "Waiting for webcam input...", (10, 100), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
+         yield dummy_frame
+         return # Exit if no frame to process
+
+     # Gradio Webcam gives NumPy array (H, W, C) in RGB
+     pil_image = Image.fromarray(input_frame)
+
+     # Apply transformations (outputs C, H, W tensor)
+     processed_frame_tensor = transform(pil_image)
+     frame_buffer.append(processed_frame_tensor)
+
+     # Perform prediction only when the buffer is full
+     if len(frame_buffer) == NUM_FRAMES:
+         # Stack the buffered frames and add a batch dimension
+         input_tensor = torch.stack(frame_buffer, dim=0).unsqueeze(0).to(device)
+
+         with torch.no_grad():
+             outputs = model(input_tensor)
+             probabilities = torch.softmax(outputs, dim=1)
+             predicted_class_idx = torch.argmax(probabilities, dim=1).item()
+
+         current_prediction_label = f"Class: {CLASS_LABELS[predicted_class_idx]}"
+         current_probabilities = {CLASS_LABELS[i]: prob.item() for i, prob in enumerate(probabilities[0])}
+
+         # --- Sliding Window ---
+         # Keep the last few frames to allow continuous predictions
+         # For example, if NUM_FRAMES is 8, and we want a new prediction every 2 frames,
+         # we slide the window by 2:
+         slide_window_by = 1 # Predict every frame (most "real-time" feel but highest compute)
+         # Or: NUM_FRAMES // 2 (e.g., predict every 4 frames for NUM_FRAMES=8)
+         # Or: NUM_FRAMES (non-overlapping windows, less frequent updates)
+         frame_buffer = frame_buffer[slide_window_by:]
+
+     # --- Draw Prediction on the current input frame ---
+     # Convert the input_frame (RGB NumPy array) to BGR for OpenCV drawing
+     display_frame = cv2.cvtColor(input_frame, cv2.COLOR_RGB2BGR)
+
+     # Draw the main prediction label
+     text_color = (0, 255, 0) # Green (BGR)
+     text_outline_color = (0, 0, 0) # Black
+     font_scale = 1.0
+     font_thickness = 2
+
+     # Draw outline first for better readability
+     cv2.putText(display_frame, current_prediction_label, (10, 40),
+                 cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_outline_color, font_thickness + 2, cv2.LINE_AA)
+     # Draw actual text
+     cv2.putText(display_frame, current_prediction_label, (10, 40),
+                 cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_color, font_thickness, cv2.LINE_AA)
+
+     # Draw probabilities for all classes (like YOLO)
+     y_offset = 80 # Start drawing probabilities slightly lower
+     for label, prob in current_probabilities.items():
+         prob_text = f"{label}: {prob:.2f}"
+         cv2.putText(display_frame, prob_text, (10, y_offset),
+                     cv2.FONT_HERSHEY_SIMPLEX, 0.7, text_outline_color, 2, cv2.LINE_AA)
+         cv2.putText(display_frame, prob_text, (10, y_offset),
+                     cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 0), 1, cv2.LINE_AA) # Yellow for probs
+         y_offset += 30 # Move down for next probability
+
+     # Yield the processed frame back to Gradio for display
+     # Gradio expects RGB NumPy array for video/image components
+     yield cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB)
+
+
+ # --- Initialize global state for the generator function ---
+ frame_buffer = [] # Buffer for collecting frames for model input
+ current_prediction_label = "Initializing..."
+ current_probabilities = {label: 0.0 for label in CLASS_LABELS} # Initial probabilities
 
  # --- 5. Gradio Interface Setup ---
  iface = gr.Interface(
-     fn=predict_video,
-     inputs=gr.Video(label="Upload Video for Violence Detection (MP4 recommended)"),
-     outputs=gr.Video(label="Processed Video with Predictions"),
-     title="Real-time Violence Detection with SmallVideoClassifier",
-     description="Upload a video, and the model will analyze it for violence, displaying the predicted class and confidence on each frame.",
-     allow_flagging="never",
-     examples=[
-         # Add example videos here for easier testing and demonstration
-         # E.g., a sample video that's publicly accessible:
-         # "https://huggingface.co/datasets/gradio/test-files/resolve/main/video.mp4"
-     ]
+     fn=predict_live_frames,
+     # Use gr.Webcam for direct webcam input
+     inputs=gr.Webcam(streaming=True, label="Live Webcam Feed for Violence Detection"),
+     # Outputs are updated continuously by the generator
+     outputs=gr.Image(type="numpy", label="Live Prediction Output"), # Using Image as output for continuous frames
+     title="Real-time Violence Detection with SmallVideoClassifier (Webcam)",
+     description=(
+         "This model detects violence in a live webcam feed. "
+         "Predictions (Class and Probabilities) will be displayed on each frame. "
+         "Please allow webcam access when prompted."
+     ),
+     allow_flagging="never", # Disable flagging on Hugging Face Spaces
  )
 
  iface.launch()
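The new app.py builds on a simple sliding-window pattern: buffer the last NUM_FRAMES preprocessed frames, run the classifier once the buffer is full, then drop the oldest frame(s) so the next prediction reuses most of the window. A minimal sketch of that pattern in isolation (illustrative only; it assumes model, transform, NUM_FRAMES, and CLASS_LABELS are defined as in the diff above):

import torch

frame_buffer = []   # holds the last NUM_FRAMES preprocessed frames
SLIDE_BY = 1        # 1 = predict on every new frame; NUM_FRAMES = non-overlapping windows

def on_new_frame(pil_frame):
    """Push one frame into the window and classify once it is full."""
    frame_buffer.append(transform(pil_frame))        # (C, H, W) tensor
    if len(frame_buffer) < NUM_FRAMES:
        return None                                  # still buffering
    clip = torch.stack(frame_buffer).unsqueeze(0)    # (1, T, C, H, W)
    with torch.no_grad():
        probs = torch.softmax(model(clip), dim=1)[0]
    del frame_buffer[:SLIDE_BY]                      # slide the window forward
    top = int(probs.argmax())
    return CLASS_LABELS[top], float(probs[top])

With SLIDE_BY = 1 every incoming frame triggers a forward pass, which is the most responsive option but also the most expensive on CPU; larger strides trade latency for throughput, exactly as the comments in the committed code describe.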
good copy.py CHANGED
@@ -1,212 +1,184 @@
  import gradio as gr
  import torch
- from transformers import AutoImageProcessor, TimesformerForVideoClassification
  import cv2
- from PIL import Image
  import numpy as np
- import time
- from collections import deque
- import base64
- import io
-
- # --- Configuration ---
- HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"
- MODEL_INPUT_NUM_FRAMES = 8
- TARGET_IMAGE_HEIGHT = 224
- TARGET_IMAGE_WIDTH = 224
- RAW_RECORDING_DURATION_SECONDS = 10.0
- FRAMES_TO_SAMPLE_PER_CLIP = 20
- DELAY_BETWEEN_PREDICTIONS_SECONDS = 120.0 # 2 minutes for CPU
-
- # --- Load Model and Processor ---
- print(f"Loading model and processor from {HF_MODEL_REPO_ID}...")
- try:
-     processor = AutoImageProcessor.from_pretrained(HF_MODEL_REPO_ID)
-     model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID)
- except Exception as e:
-     print(f"Error loading model: {e}")
-     exit()
-
- model.eval()
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ import os
+ import json
+ from PIL import Image
+ from torchvision import transforms
+ from huggingface_hub import hf_hub_download
+ import tempfile # For temporary file handling
+
+ # --- 1. Define Model Architecture (Copy from small_video_classifier.py) ---
+ # This is crucial because we need the model class definition to load weights.
+ class SmallVideoClassifier(torch.nn.Module):
+     def __init__(self, num_classes=2, num_frames=8):
+         super(SmallVideoClassifier, self).__init__()
+         from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
+         try:
+             weights = MobileNet_V3_Small_Weights.IMAGENET1K_V1
+         except Exception:
+             print("Warning: MobileNet_V3_Small_Weights.IMAGENET1K_V1 not found, initializing without pre-trained weights.")
+             weights = None
+
+         self.feature_extractor = mobilenet_v3_small(weights=weights)
+         self.feature_extractor.classifier = torch.nn.Identity()
+         self.num_spatial_features = 576
+         self.temporal_aggregator = torch.nn.AdaptiveAvgPool1d(1)
+         self.classifier = torch.nn.Sequential(
+             torch.nn.Linear(self.num_spatial_features, 512),
+             torch.nn.ReLU(),
+             torch.nn.Dropout(0.2),
+             torch.nn.Linear(512, num_classes)
+         )
+
+     def forward(self, pixel_values):
+         batch_size, num_frames, channels, height, width = pixel_values.shape
+         x = pixel_values.view(batch_size * num_frames, channels, height, width)
+         spatial_features = self.feature_extractor(x)
+         spatial_features = spatial_features.view(batch_size, num_frames, self.num_spatial_features)
+         temporal_features = self.temporal_aggregator(spatial_features.permute(0, 2, 1)).squeeze(-1)
+         logits = self.classifier(temporal_features)
+         return logits
+
+ # --- 2. Configuration and Model Loading ---
+ HF_USERNAME = "owinymarvin"
+ NEW_MODEL_REPO_ID_SHORT = "timesformer-violence-detector"
+ NEW_MODEL_REPO_ID = f"{HF_USERNAME}/{NEW_MODEL_REPO_ID_SHORT}"
+
+ print(f"Downloading config.json from {NEW_MODEL_REPO_ID}...")
+ config_path = hf_hub_download(repo_id=NEW_MODEL_REPO_ID, filename="config.json")
+ with open(config_path, 'r') as f:
+     model_config = json.load(f)
+
+ NUM_FRAMES = model_config.get('num_frames', 8)
+ IMAGE_SIZE = tuple(model_config.get('image_size', [224, 224]))
+ NUM_CLASSES = model_config.get('num_classes', 2)
+
+ CLASS_LABELS = ["Non-violence", "Violence"]
+ if NUM_CLASSES != len(CLASS_LABELS):
+     print(f"Warning: NUM_CLASSES in config ({NUM_CLASSES}) does not match hardcoded CLASS_LABELS length ({len(CLASS_LABELS)}). Adjust CLASS_LABELS if needed.")
+
+ device = torch.device("cpu")
+ print(f"Using device: {device}")
+
+ model = SmallVideoClassifier(num_classes=NUM_CLASSES, num_frames=NUM_FRAMES)
+
+ print(f"Downloading model weights from {NEW_MODEL_REPO_ID}...")
+ model_weights_path = hf_hub_download(repo_id=NEW_MODEL_REPO_ID, filename="small_violence_classifier.pth")
+ model.load_state_dict(torch.load(model_weights_path, map_location=device))
  model.to(device)
- print(f"Model loaded on {device}.")
-
- # --- Global State Variables for Live Demo ---
- raw_frames_buffer = deque()
- current_clip_start_time = time.time()
- last_prediction_completion_time = time.time()
- app_state = "recording" # States: "recording", "predicting", "processing_delay"
-
- # --- Helper function to sample frames ---
- def sample_frames(frames_list, target_count):
-     if not frames_list:
-         return []
-     if len(frames_list) <= target_count:
-         return frames_list
-     indices = np.linspace(0, len(frames_list) - 1, target_count, dtype=int)
-     sampled = [frames_list[int(i)] for i in indices]
-     return sampled
-
- # --- Main processing function for Live Demo Stream ---
- def live_predict_stream(image_np_array):
-     global raw_frames_buffer, current_clip_start_time, last_prediction_completion_time, app_state
-
-     current_time = time.time()
-     pil_image = Image.fromarray(image_np_array)
-
-     if app_state == "recording":
-         raw_frames_buffer.append(pil_image)
-         elapsed_recording_time = current_time - current_clip_start_time
-
-         yield f"Recording: {elapsed_recording_time:.1f}/{RAW_RECORDING_DURATION_SECONDS}s. Raw frames: {len(raw_frames_buffer)}", "Buffering..."
-
-         if elapsed_recording_time >= RAW_RECORDING_DURATION_SECONDS:
-             # Transition to predicting state
-             app_state = "predicting"
-             yield "Preparing to predict...", "Processing..."
-             print("DEBUG: Transitioning to 'predicting' state.")
-
-     elif app_state == "predicting":
-         # Ensure this prediction block only runs once per cycle
-         if raw_frames_buffer: # Only proceed if there are frames to process
-             print("DEBUG: Starting prediction.")
-             try:
-                 sampled_raw_frames = sample_frames(list(raw_frames_buffer), FRAMES_TO_SAMPLE_PER_CLIP)
-                 frames_for_model = sample_frames(sampled_raw_frames, MODEL_INPUT_NUM_FRAMES)
-
-                 if len(frames_for_model) < MODEL_INPUT_NUM_FRAMES:
-                     yield "Error during frame sampling.", f"Error: Not enough frames ({len(frames_for_model)}/{MODEL_INPUT_NUM_FRAMES}). Resetting."
-                     print(f"ERROR: Insufficient frames for model input: {len(frames_for_model)}/{MODEL_INPUT_NUM_FRAMES}. Resetting state.")
-                     app_state = "recording" # Reset state to start a new recording
-                     raw_frames_buffer.clear()
-                     current_clip_start_time = time.time()
-                     last_prediction_completion_time = time.time()
-                     return # Exit this stream call to wait for next frame or reset
-
-                 processed_input = processor(images=frames_for_model, return_tensors="pt")
-                 pixel_values = processed_input.pixel_values.to(device)
-
-                 with torch.no_grad():
-                     outputs = model(pixel_values)
-                     logits = outputs.logits
-
-                 predicted_class_id = logits.argmax(-1).item()
-                 predicted_label = model.config.id2label.get(predicted_class_id, "Unknown")
-                 confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()
-
-                 prediction_result = f"Predicted: {predicted_label} (Confidence: {confidence:.2f})"
-                 status_message = "Prediction complete."
-                 print(f"DEBUG: Prediction Result: {prediction_result}")
-
-                 # Yield the prediction result immediately to ensure UI update
-                 yield status_message, prediction_result
-
-                 # Clear buffer and transition to delay AFTER yielding the prediction
-                 raw_frames_buffer.clear()
-                 last_prediction_completion_time = current_time
-                 app_state = "processing_delay"
-                 print("DEBUG: Transitioning to 'processing_delay' state.")
-
-             except Exception as e:
-                 error_message = f"Error during prediction: {e}"
-                 print(f"ERROR during prediction: {e}")
-                 # Yield error to UI
-                 yield "Prediction error.", error_message
-                 app_state = "processing_delay" # Still go to delay state to prevent constant errors
-                 raw_frames_buffer.clear() # Clear buffer to prevent re-processing same problematic frames
-
-     elif app_state == "processing_delay":
-         elapsed_delay = current_time - last_prediction_completion_time
-
-         if elapsed_delay < DELAY_BETWEEN_PREDICTIONS_SECONDS:
-             # Continue yielding the delay message and the last prediction result
-             # Assuming prediction_result from previous state is still held by UI
-             yield f"Delaying next prediction: {int(elapsed_delay)}/{int(DELAY_BETWEEN_PREDICTIONS_SECONDS)}s", gr.NO_VALUE # NO_VALUE keeps previous prediction visible
-         else:
-             # Delay is over, reset for new recording cycle
-             app_state = "recording"
-             current_clip_start_time = current_time
-             print("DEBUG: Transitioning back to 'recording' state.")
-             yield "Starting new recording...", "Ready for new prediction."
-
-     # If for some reason nothing is yielded, return the current state to prevent UI freeze.
-     # This acts as a fallback if no state transition happens.
-     # However, with the yield statements, this might be less critical.
-     # For streaming, yielding is the preferred way to update.
-     # If the function ends without yielding, Gradio will just keep the last state.
-     # We always yield in every branch.
-     pass # No explicit return needed at the end if all paths yield
-
- def reset_app_state_manual():
-     global raw_frames_buffer, current_clip_start_time, last_prediction_completion_time, app_state
-     raw_frames_buffer.clear()
-     current_clip_start_time = time.time()
-     last_prediction_completion_time = time.time()
-     app_state = "recording"
-     print("DEBUG: Manual reset triggered.")
-     # Return initial values immediately upon reset
-     return "Ready to record...", "Ready for new prediction."
-
- # --- Gradio UI Layout ---
- with gr.Blocks() as demo:
-     gr.Markdown(
-         f"""
-         # TimesFormer Crime Detection - Hugging Face Space Host
-         This Space hosts the `owinymarvin/timesformer-crime-detection` model.
-         Live webcam demo with recording and prediction phases.
-         """
-     )
-
-     with gr.Tab("Live Webcam Demo"):
-         gr.Markdown(
-             f"""
-             Continuously captures live webcam feed for **{RAW_RECORDING_DURATION_SECONDS} seconds**,
-             then makes a prediction. There is a **{DELAY_BETWEEN_PREDICTIONS_SECONDS/60:.0f} minute delay** afterwards.
-             """
-         )
-         with gr.Row():
-             with gr.Column():
-                 webcam_input = gr.Image(
-                     sources=["webcam"],
-                     streaming=True,
-                     label="Live Webcam Feed"
-                 )
-                 status_output = gr.Textbox(label="Current Status", value="Initializing...")
-                 reset_button = gr.Button("Reset / Start New Cycle")
-             with gr.Column():
-                 prediction_output = gr.Textbox(label="Prediction Result", value="Waiting...")
-
-         # IMPORTANT: Use webcam_input.stream() with a generator function (live_predict_stream)
-         # to enable progressive updates via 'yield'.
-         webcam_input.stream(
-             live_predict_stream,
-             inputs=[webcam_input],
-             outputs=[status_output, prediction_output]
-         )
-
-         # The reset button is a regular click event, not a stream
-         reset_button.click(
-             reset_app_state_manual,
-             inputs=[],
-             outputs=[status_output, prediction_output]
-         )
-
-     with gr.Tab("API Endpoint for External Clients"):
-         gr.Markdown(
-             """
-             Use this API endpoint to send base64-encoded frames for prediction.
-             """
-         )
-         # Placeholder for the API tab. The actual API calls target /run/predict_from_frames_api
-         gr.Interface(
-             fn=lambda frames_list: "API endpoint is active for programmatic calls. See documentation in app.py.",
-             inputs=gr.Json(label="List of Base64-encoded image strings"),
-             outputs=gr.Textbox(label="API Response"),
-             live=False,
-             allow_flagging="never"
-         )
-
-
- if __name__ == "__main__":
-     demo.launch()
+ model.eval()
+
+ print(f"Model loaded successfully with {NUM_FRAMES} frames and image size {IMAGE_SIZE}.")
+
+ # --- 3. Define Preprocessing Transform ---
+ transform = transforms.Compose([
+     transforms.Resize(IMAGE_SIZE),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ ])
+
+ # --- 4. Gradio Inference Function ---
+ def predict_video(video_path):
+     if video_path is None:
+         return None
+
+     cap = cv2.VideoCapture(video_path)
+
+     if not cap.isOpened():
+         print(f"Error: Could not open video file {video_path}.")
+         raise ValueError(f"Could not open video file {video_path}. Please ensure it's a valid video format.")
+
+     frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+     fps = cap.get(cv2.CAP_PROP_FPS)
+     # Ensure FPS is not zero to avoid division by zero errors, default to 25 if needed
+     if fps <= 0:
+         fps = 25.0
+         print(f"Warning: Original video FPS was 0 or less, defaulting to {fps}.")
+
+     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+     temp_output_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
+     output_video_path = temp_output_file.name
+     temp_output_file.close()
+
+     # --- CHANGED: Use XVID codec for better browser compatibility ---
+     # This might prevent Gradio's internal re-encoding.
+     fourcc = cv2.VideoWriter_fourcc(*'XVID')
+     out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))
+
+     print(f"Processing video: {video_path}")
+     print(f"Total frames: {total_frames}, FPS: {fps}")
+     print(f"Output video will be saved to: {output_video_path}")
+
+     frame_buffer = []
+     current_prediction_label = "Processing..."
+
+     frame_idx = 0
+     while True:
+         ret, frame = cap.read()
+         if not ret:
+             break
+
+         frame_idx += 1
+
+         frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+         pil_image = Image.fromarray(frame_rgb)
+
+         processed_frame = transform(pil_image)
+         frame_buffer.append(processed_frame)
+
+         if len(frame_buffer) == NUM_FRAMES:
+             input_tensor = torch.stack(frame_buffer, dim=0).unsqueeze(0).to(device)
+
+             with torch.no_grad():
+                 outputs = model(input_tensor)
+                 probabilities = torch.softmax(outputs, dim=1)
+                 predicted_class_idx = torch.argmax(probabilities, dim=1).item()
+                 current_prediction_label = f"Prediction: {CLASS_LABELS[predicted_class_idx]} (Prob: {probabilities[0, predicted_class_idx]:.2f})"
+
+             frame_buffer = []
+             # If you want a sliding window, you would do something like:
+             # frame_buffer = frame_buffer[int(NUM_FRAMES * 0.5):] # Slide by half the window size
+
+         # Draw prediction text on the current frame
+         # Ensure text color is clearly visible (e.g., white or bright green)
+         # Add a black outline for better readability
+         text_color = (0, 255, 0) # Green (BGR format for OpenCV)
+         text_outline_color = (0, 0, 0) # Black
+         font_scale = 1.0 # Increased font size
+         font_thickness = 2
+
+         # Draw outline first for better readability
+         cv2.putText(frame, current_prediction_label, (10, 40), # Slightly lower position
+                     cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_outline_color, font_thickness + 2, cv2.LINE_AA)
+         # Draw actual text
+         cv2.putText(frame, current_prediction_label, (10, 40),
+                     cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_color, font_thickness, cv2.LINE_AA)
+
+         out.write(frame)
+
+     cap.release()
+     out.release()
+     print(f"Video processing complete. Output saved to: {output_video_path}")
+
+     return output_video_path
+
+ # --- 5. Gradio Interface Setup ---
+ iface = gr.Interface(
+     fn=predict_video,
+     inputs=gr.Video(label="Upload Video for Violence Detection (MP4 recommended)"),
+     outputs=gr.Video(label="Processed Video with Predictions"),
+     title="Real-time Violence Detection with SmallVideoClassifier",
+     description="Upload a video, and the model will analyze it for violence, displaying the predicted class and confidence on each frame.",
+     allow_flagging="never",
+     examples=[
+         # Add example videos here for easier testing and demonstration
+         # E.g., a sample video that's publicly accessible:
+         # "https://huggingface.co/datasets/gradio/test-files/resolve/main/video.mp4"
+     ]
+ )
+
+ iface.launch()
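A note on the writer setup in predict_video above: OpenCV's 'XVID' FourCC is conventionally paired with .avi containers, and FFmpeg-backed OpenCV builds typically warn and substitute another tag when XVID is written into an .mp4 file, so the "browser compatibility" benefit is not guaranteed. If the produced file does not play back, one variant worth trying (purely illustrative, not part of this commit) is:

# Illustrative fallback only, not part of the commit: prefer 'mp4v' for .mp4
# output and drop back to XVID/.avi if the writer fails to open.
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))
if not out.isOpened():
    output_video_path = output_video_path.replace(".mp4", ".avi")
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))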