owinymarvin commited on
Commit
429fbf1
·
1 Parent(s): 998f789

latest changes

Browse files
Files changed (1) hide show
  1. app.py +65 -52
app.py CHANGED
@@ -7,9 +7,9 @@ import json
7
  from PIL import Image
8
  from torchvision import transforms
9
  from huggingface_hub import hf_hub_download
10
- import time # For potential sleep to control frame rate if needed
11
 
12
- # --- 1. Define Model Architecture (Copy from small_video_classifier.py) ---
13
  class SmallVideoClassifier(torch.nn.Module):
14
  def __init__(self, num_classes=2, num_frames=8):
15
  super(SmallVideoClassifier, self).__init__()
@@ -58,7 +58,7 @@ CLASS_LABELS = ["Non-violence", "Violence"]
58
  if NUM_CLASSES != len(CLASS_LABELS):
59
  print(f"Warning: NUM_CLASSES in config ({NUM_CLASSES}) does not match hardcoded CLASS_LABELS length ({len(CLASS_LABELS)}). Adjust CLASS_LABELS if needed.")
60
 
61
- device = torch.device("cpu") # Explicitly use CPU
62
  print(f"Using device: {device}")
63
 
64
  model = SmallVideoClassifier(num_classes=NUM_CLASSES, num_frames=NUM_FRAMES)
@@ -78,34 +78,30 @@ transform = transforms.Compose([
78
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
79
  ])
80
 
81
- # --- 4. Gradio Live Inference Function (Generator) ---
82
- # This function will receive individual frames from the webcam
83
- # Initialize global state for the generator function (before the predict function)
84
- frame_buffer = [] # Buffer for collecting frames for model input
85
  current_prediction_label = "Initializing..."
86
- current_probabilities = {label: 0.0 for label in CLASS_LABELS} # Initial probabilities
87
 
 
88
  def predict_live_frames(input_frame):
89
- global frame_buffer, current_prediction_label, current_probabilities # Use global to maintain state across calls
90
 
91
  if input_frame is None:
92
- # If no frame is received (e.g., webcam not active), yield a black frame or handle gracefully
93
  dummy_frame = np.zeros((200, 400, 3), dtype=np.uint8)
94
  cv2.putText(dummy_frame, "Waiting for webcam input...", (10, 100), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
95
  yield dummy_frame
96
- return # Exit if no frame to process
97
 
98
- # Gradio Webcam gives NumPy array (H, W, C) in RGB
99
  pil_image = Image.fromarray(input_frame)
100
-
101
- # Apply transformations (outputs C, H, W tensor)
102
  processed_frame_tensor = transform(pil_image)
103
  frame_buffer.append(processed_frame_tensor)
104
 
105
- # Perform prediction only when the buffer is full
106
- if len(frame_buffer) == NUM_FRAMES:
107
- # Stack the buffered frames and add a batch dimension
108
- input_tensor = torch.stack(frame_buffer, dim=0).unsqueeze(0).to(device)
109
 
110
  with torch.no_grad():
111
  outputs = model(input_tensor)
@@ -115,15 +111,8 @@ def predict_live_frames(input_frame):
115
  current_prediction_label = f"Class: {CLASS_LABELS[predicted_class_idx]}"
116
  current_probabilities = {CLASS_LABELS[i]: prob.item() for i, prob in enumerate(probabilities[0])}
117
 
118
- # --- Sliding Window ---
119
- # Keep the last few frames to allow continuous predictions
120
- slide_window_by = 1 # Predict every frame (most "real-time" feel but highest compute)
121
- # Or: NUM_FRAMES // 2 (e.g., predict every 4 frames for NUM_FRAMES=8)
122
- # Or: NUM_FRAMES (non-overlapping windows, less frequent updates)
123
- frame_buffer = frame_buffer[slide_window_by:]
124
-
125
- # --- Draw Prediction on the current input frame ---
126
- # Convert the input_frame (RGB NumPy array) to BGR for OpenCV drawing
127
  display_frame = cv2.cvtColor(input_frame, cv2.COLOR_RGB2BGR)
128
 
129
  # Draw the main prediction label
@@ -132,42 +121,66 @@ def predict_live_frames(input_frame):
132
  font_scale = 1.0
133
  font_thickness = 2
134
 
135
- # Draw outline first for better readability
136
  cv2.putText(display_frame, current_prediction_label, (10, 40),
137
  cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_outline_color, font_thickness + 2, cv2.LINE_AA)
138
- # Draw actual text
139
  cv2.putText(display_frame, current_prediction_label, (10, 40),
140
  cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_color, font_thickness, cv2.LINE_AA)
141
 
142
- # Draw probabilities for all classes (like YOLO)
143
- y_offset = 80 # Start drawing probabilities slightly lower
144
  for label, prob in current_probabilities.items():
145
  prob_text = f"{label}: {prob:.2f}"
146
  cv2.putText(display_frame, prob_text, (10, y_offset),
147
  cv2.FONT_HERSHEY_SIMPLEX, 0.7, text_outline_color, 2, cv2.LINE_AA)
148
  cv2.putText(display_frame, prob_text, (10, y_offset),
149
- cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 0), 1, cv2.LINE_AA) # Yellow for probs
150
- y_offset += 30 # Move down for next probability
151
 
152
- # Yield the processed frame back to Gradio for display
153
- # Gradio expects RGB NumPy array for video/image components
154
  yield cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB)
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
- # --- 5. Gradio Interface Setup ---
158
- iface = gr.Interface(
159
- fn=predict_live_frames,
160
- # CORRECTED: Use gr.Video with sources=["webcam"] for webcam input
161
- inputs=gr.Video(sources=["webcam"], streaming=True, label="Live Webcam Feed for Violence Detection"),
162
- # Outputs are updated continuously by the generator
163
- outputs=gr.Image(type="numpy", label="Live Prediction Output"), # Using Image as output for continuous frames
164
- title="Real-time Violence Detection with SmallVideoClassifier (Webcam)",
165
- description=(
166
- "This model detects violence in a live webcam feed. "
167
- "Predictions (Class and Probabilities) will be displayed on each frame. "
168
- "Please allow webcam access when prompted."
169
- ),
170
- allow_flagging="never", # Disable flagging on Hugging Face Spaces
171
- )
172
-
173
- iface.launch()
 
7
  from PIL import Image
8
  from torchvision import transforms
9
  from huggingface_hub import hf_hub_download
10
+ import time
11
 
12
+ # --- 1. Define Model Architecture ---
13
  class SmallVideoClassifier(torch.nn.Module):
14
  def __init__(self, num_classes=2, num_frames=8):
15
  super(SmallVideoClassifier, self).__init__()
 
58
  if NUM_CLASSES != len(CLASS_LABELS):
59
  print(f"Warning: NUM_CLASSES in config ({NUM_CLASSES}) does not match hardcoded CLASS_LABELS length ({len(CLASS_LABELS)}). Adjust CLASS_LABELS if needed.")
60
 
61
+ device = torch.device("cpu")
62
  print(f"Using device: {device}")
63
 
64
  model = SmallVideoClassifier(num_classes=NUM_CLASSES, num_frames=NUM_FRAMES)
 
78
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
79
  ])
80
 
81
+ # --- Global state for the generator function ---
82
+ frame_buffer = []
 
 
83
  current_prediction_label = "Initializing..."
84
+ current_probabilities = {label: 0.0 for label in CLASS_LABELS}
85
 
86
+ # --- 4. Gradio Live Inference Function (Generator) ---
87
  def predict_live_frames(input_frame):
88
+ global frame_buffer, current_prediction_label, current_probabilities
89
 
90
  if input_frame is None:
91
+ # If no frame is received (e.g., webcam not active or disconnected)
92
  dummy_frame = np.zeros((200, 400, 3), dtype=np.uint8)
93
  cv2.putText(dummy_frame, "Waiting for webcam input...", (10, 100), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
94
  yield dummy_frame
95
+ return
96
 
 
97
  pil_image = Image.fromarray(input_frame)
 
 
98
  processed_frame_tensor = transform(pil_image)
99
  frame_buffer.append(processed_frame_tensor)
100
 
101
+ slide_window_by = 1
102
+
103
+ if len(frame_buffer) >= NUM_FRAMES:
104
+ input_tensor = torch.stack(frame_buffer[-NUM_FRAMES:], dim=0).unsqueeze(0).to(device)
105
 
106
  with torch.no_grad():
107
  outputs = model(input_tensor)
 
111
  current_prediction_label = f"Class: {CLASS_LABELS[predicted_class_idx]}"
112
  current_probabilities = {CLASS_LABELS[i]: prob.item() for i, prob in enumerate(probabilities[0])}
113
 
114
+ frame_buffer = frame_buffer[slide_window_by:]
115
+
 
 
 
 
 
 
 
116
  display_frame = cv2.cvtColor(input_frame, cv2.COLOR_RGB2BGR)
117
 
118
  # Draw the main prediction label
 
121
  font_scale = 1.0
122
  font_thickness = 2
123
 
 
124
  cv2.putText(display_frame, current_prediction_label, (10, 40),
125
  cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_outline_color, font_thickness + 2, cv2.LINE_AA)
 
126
  cv2.putText(display_frame, current_prediction_label, (10, 40),
127
  cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_color, font_thickness, cv2.LINE_AA)
128
 
129
+ # Draw probabilities for all classes
130
+ y_offset = 80
131
  for label, prob in current_probabilities.items():
132
  prob_text = f"{label}: {prob:.2f}"
133
  cv2.putText(display_frame, prob_text, (10, y_offset),
134
  cv2.FONT_HERSHEY_SIMPLEX, 0.7, text_outline_color, 2, cv2.LINE_AA)
135
  cv2.putText(display_frame, prob_text, (10, y_offset),
136
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 0), 1, cv2.LINE_AA)
137
+ y_offset += 30
138
 
 
 
139
  yield cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB)
140
 
141
+ # --- 5. Gradio Blocks Interface Setup ---
142
+ with gr.Blocks(
143
+ title="Real-time Violence Detection", # Title for the browser tab
144
+ theme=gr.themes.Default(primary_hue=gr.Color(c50='#e0f7fa', c100='#b2ebf2', c200='#80deea', c300='#4dd0e1', c400='#26c6da', c500='#00bcd4', c600='#00acc1', c700='#0097a7', c800='#00838f', c900='#006064', ca50='#84ffff', ca100='#18ffff', ca200='#00e5ff', ca400='#00b8d4')) # Optional: A subtle theme change
145
+ ) as demo:
146
+ # Optional: Display a title and description clearly, even without buttons
147
+ gr.Markdown(
148
+ """
149
+ # 🎬 Real-time Violence Detection
150
+ **Live Feed with Constant Predictions**
151
+
152
+ This model analyzes your live webcam feed for violence, displaying the predicted class and probabilities on the screen.
153
+ Please grant webcam access when prompted by your browser.
154
+ """
155
+ )
156
+
157
+ with gr.Row():
158
+ # Input: Live webcam feed
159
+ # We need to set a minimum height and width to ensure the video feed is displayed reasonably
160
+ video_input = gr.Video(
161
+ sources=["webcam"],
162
+ streaming=True,
163
+ label="Live Webcam Feed",
164
+ # Optional: Set dimensions for the video display
165
+ height=480, # or None for auto
166
+ width=640 # or None for auto
167
+ )
168
+
169
+ # Output: Image component to display processed frames
170
+ video_output = gr.Image(
171
+ type="numpy",
172
+ label="Processed Feed with Predictions",
173
+ # Optional: Set dimensions to match input or your preference
174
+ height=480, # or None for auto
175
+ width=640 # or None for auto
176
+ )
177
+
178
+ # Connect the video stream directly to the prediction function
179
+ # The 'stream' event on gr.Video is triggered as new frames arrive from the webcam.
180
+ video_input.stream(
181
+ predict_live_frames, # The function to call for each frame
182
+ inputs=video_input, # Pass the video_input component itself as input
183
+ outputs=video_output # Update the video_output component
184
+ )
185
 
186
+ demo.launch()