owinymarvin commited on
Commit
4784ef2
·
1 Parent(s): bc79b5b

latest changes

Browse files
Files changed (1) hide show
  1. app.py +133 -69
app.py CHANGED
@@ -12,17 +12,22 @@ from collections import deque
12
  HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"
13
 
14
  # These must match the values used during your training
15
- NUM_FRAMES = 8 # Still using 8 frames, as that was your original training setup
 
 
 
 
16
  TARGET_IMAGE_HEIGHT = 224
17
  TARGET_IMAGE_WIDTH = 224
18
 
19
- # --- Prediction Timing ---
20
- # How long to record (in seconds) before making a prediction
21
- RECORDING_DURATION_SECONDS = 10.0 # CHANGED: Now records for 10 seconds
22
- # How often the model should predict (after the recording duration)
23
- # Setting this to a very high number (like 9999) means it essentially predicts only once
24
- # after the recording is done until reset. Or you can leave it at 1.0 if you want it to trigger often.
25
- INFERENCE_INTERVAL_SECONDS = 1.0 # This will be the minimum time between predictions if not controlled by reset.
 
26
 
27
 
28
  # --- Load Model and Processor ---
@@ -41,40 +46,83 @@ print(f"Model loaded successfully on {device}.")
41
  print(f"Model's class labels: {model.config.id2label}")
42
 
43
  # --- Global State Variables ---
44
- # Use a global deque to store captured frames
45
- captured_frames_buffer = deque(maxlen=NUM_FRAMES)
46
- recording_start_time = None # To track when recording for a clip started
47
- last_prediction_time = time.time() # To control prediction frequency after recording
48
-
49
- # --- Functions for Gradio Interface ---
50
-
51
- def process_frame_and_predict(image_np_array):
52
- global captured_frames_buffer, recording_start_time, last_prediction_time
53
-
54
- # Initialize recording_start_time if it's the first frame for a new recording cycle
55
- if recording_start_time is None:
56
- recording_start_time = time.time()
57
- captured_frames_buffer.clear() # Clear buffer to start a new clip
58
-
59
- # Convert Gradio's numpy array (RGB) to PIL Image
60
- pil_image = Image.fromarray(image_np_array)
61
- captured_frames_buffer.append(pil_image)
 
 
 
 
 
 
 
 
 
 
62
 
63
  current_time = time.time()
64
- elapsed_recording_time = current_time - recording_start_time
65
 
66
- output_status = f"Recording: {elapsed_recording_time:.1f}/{RECORDING_DURATION_SECONDS}s | Frames: {len(captured_frames_buffer)}/{NUM_FRAMES}"
67
- prediction_text = "Recording..." # Default text while recording
68
 
69
- # Check if enough time has passed and we have enough frames
70
- if elapsed_recording_time >= RECORDING_DURATION_SECONDS and len(captured_frames_buffer) >= NUM_FRAMES:
71
- if (current_time - last_prediction_time) >= INFERENCE_INTERVAL_SECONDS: # Limit prediction frequency
72
- # --- Perform Inference ---
73
- print(f"Triggered inference on {len(captured_frames_buffer)} frames after {RECORDING_DURATION_SECONDS}s recording...")
74
- frames_for_prediction = list(captured_frames_buffer) # Take a snapshot
75
 
76
- # The image_processor will handle the resizing to TARGET_IMAGE_HEIGHT x TARGET_IMAGE_WIDTH
77
- processed_input = processor(images=frames_for_prediction, return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  pixel_values = processed_input.pixel_values.to(device)
79
 
80
  with torch.no_grad():
@@ -85,38 +133,54 @@ def process_frame_and_predict(image_np_array):
85
  predicted_label = model.config.id2label[predicted_class_id]
86
  confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()
87
 
88
- prediction_text = f"Predicted: {predicted_label} ({confidence:.2f})"
89
- print(prediction_text) # Print to Space logs
90
 
91
- last_prediction_time = current_time # Update time of last successful prediction
92
-
93
- # Reset recording_start_time to allow a new recording cycle
94
- recording_start_time = None
95
- captured_frames_buffer.clear() # Clear buffer for next clip
96
  else:
97
- prediction_text = "Prediction done. Waiting for next interval..." # Message if prediction recently made
98
-
99
- return output_status, prediction_text
100
-
101
- def reset_app_state():
102
- """Resets the global state variables to start a new recording/prediction cycle."""
103
- global captured_frames_buffer, recording_start_time, last_prediction_time
104
- captured_frames_buffer.clear()
105
- recording_start_time = None
106
- last_prediction_time = time.time()
107
- print("App state reset.")
108
- # Return initial messages for the UI
109
- return "Ready to record...", "Ready for new prediction."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  # --- Gradio Interface ---
112
  with gr.Blocks() as demo:
113
  gr.Markdown(
114
  f"""
115
- # TimesFormer Crime Detection Live Demo (Auto-Triggered Clip Prediction)
116
- This demo uses a finetuned TimesFormer model ({HF_MODEL_REPO_ID}) to predict crime actions from a live webcam feed.
117
- It records **{RECORDING_DURATION_SECONDS} seconds** of video, then automatically triggers a prediction.
118
- The model processes **{NUM_FRAMES} frames** per prediction.
119
- Click 'Reset' to start a new video recording.
 
120
  Please allow webcam access.
121
  """
122
  )
@@ -128,27 +192,27 @@ with gr.Blocks() as demo:
128
  label="Live Webcam Feed"
129
  )
130
  # Textboxes for status and prediction
131
- status_output = gr.Textbox(label="Status", value="Ready to record...")
132
 
133
  # Reset Button
134
- reset_button = gr.Button("Reset / Start New Video")
135
 
136
  with gr.Column():
137
- prediction_output = gr.Textbox(label="Prediction Result", value="Recording will start automatically.")
138
 
139
  # Define actions
140
  # This continuously processes frames from the webcam
141
  webcam_input.stream(
142
- process_frame_and_predict,
143
  inputs=[webcam_input],
144
- outputs=[status_output, prediction_output] # Now outputs both status and prediction
145
  )
146
 
147
  # This triggers the reset function when the button is clicked
148
  reset_button.click(
149
- reset_app_state,
150
  inputs=[],
151
- outputs=[status_output, prediction_output] # Updates both output textboxes
152
  )
153
 
154
  if __name__ == "__main__":
 
12
  HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"
13
 
14
  # These must match the values used during your training
15
+ # IMPORTANT: Your model was trained on NUM_FRAMES = 8.
16
+ # If you want to use 20 frames, this model will likely NOT perform well
17
+ # as it's a mismatch. If you truly need 20 frames, the model should be retrained with 20.
18
+ # For now, let's keep it at 8 as per your training, but we can simulate 20 captured for sampling.
19
+ MODEL_INPUT_NUM_FRAMES = 8 # This is the 'NUM_FRAMES' the model expects
20
  TARGET_IMAGE_HEIGHT = 224
21
  TARGET_IMAGE_WIDTH = 224
22
 
23
+ # --- Video Capture & Prediction Timing ---
24
+ RAW_RECORDING_DURATION_SECONDS = 10.0 # Capture raw frames for this duration for each clip
25
+ FRAMES_TO_SAMPLE_PER_CLIP = 20 # Number of frames to hypothetically sample from the raw 10s clip
26
+ # NOTE: The model will only use MODEL_INPUT_NUM_FRAMES (8) of these.
27
+
28
+ # The delay *after* a prediction is made before the next prediction cycle starts.
29
+ # Set to 120.0 seconds (2 minutes) for CPU testing. Change this for GPU.
30
+ DELAY_BETWEEN_PREDICTIONS_SECONDS = 120.0 # CHANGED: Variable for delay between predictions
31
 
32
 
33
  # --- Load Model and Processor ---
 
46
  print(f"Model's class labels: {model.config.id2label}")
47
 
48
  # --- Global State Variables ---
49
+ # Buffer to store raw frames from the webcam for the current 10-second segment
50
+ raw_frames_buffer = deque() # No maxlen, we manage size based on time
51
+ current_clip_start_time = time.time() # Time when the current 10-second clip started
52
+ last_prediction_completion_time = time.time() # Time when the last prediction finished
53
+
54
+ # State machine for the app's workflow
55
+ # States: "recording", "processing_delay", "predicting"
56
+ app_state = "recording"
57
+
58
+ # --- Helper function to sample frames ---
59
+ def sample_frames(frames_list, target_count):
60
+ """
61
+ Samples target_count frames evenly from a list of frames.
62
+ If frames_list has fewer than target_count, it returns all frames.
63
+ """
64
+ if not frames_list:
65
+ return []
66
+ if len(frames_list) <= target_count:
67
+ return frames_list
68
+
69
+ indices = np.linspace(0, len(frames_list) - 1, target_count, dtype=int)
70
+ sampled = [frames_list[i] for i in indices]
71
+ return sampled
72
+
73
+
74
+ # --- Main processing function for Gradio Stream ---
75
+ def live_predict_stream(image_np_array):
76
+ global raw_frames_buffer, current_clip_start_time, last_prediction_completion_time, app_state
77
 
78
  current_time = time.time()
79
+ pil_image = Image.fromarray(cv2.cvtColor(image_np_array, cv2.COLOR_RGB2BGR)) # Convert RGB to BGR if using cv2.putText later, otherwise RGB is fine
80
 
81
+ status_message = ""
82
+ prediction_result = ""
83
 
84
+ if app_state == "recording":
85
+ raw_frames_buffer.append(pil_image)
86
+ elapsed_recording_time = current_time - current_clip_start_time
 
 
 
87
 
88
+ if elapsed_recording_time < RAW_RECORDING_DURATION_SECONDS:
89
+ status_message = f"Recording: {elapsed_recording_time:.1f}/{RAW_RECORDING_DURATION_SECONDS}s. Total raw frames: {len(raw_frames_buffer)}"
90
+ prediction_result = "Buffering for next clip..."
91
+ else:
92
+ # Done recording, now move to predicting state
93
+ app_state = "predicting"
94
+ status_message = f"Finished recording {RAW_RECORDING_DURATION_SECONDS}s. Preparing for prediction..."
95
+ prediction_result = "Processing clip..."
96
+ print(f"DEBUG: Entering 'predicting' state. Raw frames collected: {len(raw_frames_buffer)}")
97
+
98
+ if app_state == "predicting":
99
+ # Ensure prediction logic runs only once per clip
100
+ if raw_frames_buffer: # Check if there are frames to process
101
+ print(f"DEBUG: Performing prediction.")
102
+
103
+ # 1. Sample FRAMES_TO_SAMPLE_PER_CLIP from the raw buffer
104
+ # Note: Your model was trained on MODEL_INPUT_NUM_FRAMES.
105
+ # We'll sample 20 from the raw, but then further sample 8 for the model.
106
+ sampled_raw_frames = sample_frames(list(raw_frames_buffer), FRAMES_TO_SAMPLE_PER_CLIP)
107
+
108
+ # 2. Select MODEL_INPUT_NUM_FRAMES from the sampled frames for the model
109
+ frames_for_model = sample_frames(sampled_raw_frames, MODEL_INPUT_NUM_FRAMES)
110
+
111
+ if len(frames_for_model) < MODEL_INPUT_NUM_FRAMES:
112
+ # This should ideally not happen if RAW_RECORDING_DURATION_SECONDS is long enough
113
+ # and camera FPS is stable.
114
+ prediction_result = "Not enough frames for model input. Waiting for more..."
115
+ status_message = "Error: Not enough frames for model."
116
+ print(f"WARNING: Insufficient frames for model input: {len(frames_for_model)}/{MODEL_INPUT_NUM_FRAMES}")
117
+ # Reset state if we can't predict
118
+ app_state = "recording"
119
+ raw_frames_buffer.clear()
120
+ current_clip_start_time = time.time()
121
+ last_prediction_completion_time = time.time() # Reset delay counter too
122
+ return status_message, prediction_result
123
+
124
+ # Preprocess and predict
125
+ processed_input = processor(images=frames_for_model, return_tensors="pt")
126
  pixel_values = processed_input.pixel_values.to(device)
127
 
128
  with torch.no_grad():
 
133
  predicted_label = model.config.id2label[predicted_class_id]
134
  confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()
135
 
136
+ prediction_result = f"Predicted: {predicted_label} ({confidence:.2f})"
137
+ print(f"DEBUG: {prediction_result}")
138
 
139
+ # Clear raw buffer as this clip has been processed
140
+ raw_frames_buffer.clear()
141
+ last_prediction_completion_time = current_time # Mark time prediction finished
142
+ app_state = "processing_delay" # Move to delay state
143
+ status_message = f"Prediction complete. Waiting for {DELAY_BETWEEN_PREDICTIONS_SECONDS}s delay."
144
  else:
145
+ # This means app_state is predicting but raw_frames_buffer is empty, should not happen in normal flow
146
+ status_message = "Waiting for frames to process..."
147
+ prediction_result = "..."
148
+
149
+ elif app_state == "processing_delay":
150
+ elapsed_delay = current_time - last_prediction_completion_time
151
+ if elapsed_delay < DELAY_BETWEEN_PREDICTIONS_SECONDS:
152
+ status_message = f"Delaying next prediction: {int(elapsed_delay)}/{DELAY_BETWEEN_PREDICTIONS_SECONDS}s"
153
+ # Keep showing the last prediction result during the delay
154
+ else:
155
+ # Delay is over, reset for next recording cycle
156
+ app_state = "recording"
157
+ current_clip_start_time = current_time # Start new recording clip
158
+ status_message = "Delay finished. Starting new recording..."
159
+ prediction_result = "Recording for next clip..."
160
+ print(f"DEBUG: Delay finished. Entering 'recording' state.")
161
+
162
+ return status_message, prediction_result
163
+
164
+ def reset_app_state_manual():
165
+ """Resets the global state variables and starts a new recording cycle immediately."""
166
+ global raw_frames_buffer, current_clip_start_time, last_prediction_completion_time, app_state
167
+ raw_frames_buffer.clear()
168
+ current_clip_start_time = time.time()
169
+ last_prediction_completion_time = time.time()
170
+ app_state = "recording" # Force state to recording
171
+ print("Manual reset: App state reset and starting new recording cycle.")
172
+ return "Ready to record...", "Ready for new prediction cycle."
173
 
174
  # --- Gradio Interface ---
175
  with gr.Blocks() as demo:
176
  gr.Markdown(
177
  f"""
178
+ # TimesFormer Crime Detection Live Demo (Segmented Auto-Prediction)
179
+ This demo continuously captures live webcam feed.
180
+ It records raw video for **{RAW_RECORDING_DURATION_SECONDS} seconds**.
181
+ From this, it samples **{FRAMES_TO_SAMPLE_PER_CLIP} frames** (for context) and then extracts **{MODEL_INPUT_NUM_FRAMES} frames**
182
+ for the TimesFormer model to make a prediction.
183
+ After each prediction, there's a **{DELAY_BETWEEN_PREDICTIONS_SECONDS/60:.0f} minute delay** before the next prediction cycle begins.
184
  Please allow webcam access.
185
  """
186
  )
 
192
  label="Live Webcam Feed"
193
  )
194
  # Textboxes for status and prediction
195
+ status_output = gr.Textbox(label="Current Status", value="Initializing...")
196
 
197
  # Reset Button
198
+ reset_button = gr.Button("Manual Reset / Start New Cycle Immediately")
199
 
200
  with gr.Column():
201
+ prediction_output = gr.Textbox(label="Prediction Result", value="Waiting for recording to start...")
202
 
203
  # Define actions
204
  # This continuously processes frames from the webcam
205
  webcam_input.stream(
206
+ live_predict_stream,
207
  inputs=[webcam_input],
208
+ outputs=[status_output, prediction_output]
209
  )
210
 
211
  # This triggers the reset function when the button is clicked
212
  reset_button.click(
213
+ reset_app_state_manual,
214
  inputs=[],
215
+ outputs=[status_output, prediction_output]
216
  )
217
 
218
  if __name__ == "__main__":