owinymarvin committed on
Commit
6f472e5
·
1 Parent(s): 9d0ee1c

latest changes

Browse files
Files changed (1) hide show
  1. app.py +79 -58
app.py CHANGED
@@ -12,10 +12,19 @@ from collections import deque
12
  HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"
13
 
14
  # These must match the values used during your training
15
- NUM_FRAMES = 16 # Still expecting 16 frames for a batch
16
  TARGET_IMAGE_HEIGHT = 224
17
  TARGET_IMAGE_WIDTH = 224
18
 
 
 
 
 
 
 
 
 
 
19
  # --- Load Model and Processor ---
20
  print(f"Loading model and image processor from {HF_MODEL_REPO_ID}...")
21
  try:
@@ -23,7 +32,6 @@ try:
23
  model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID)
24
  except Exception as e:
25
  print(f"Error loading model from Hugging Face Hub: {e}")
26
- # Handle error - exit or raise exception for Space to fail gracefully
27
  exit()
28
 
29
  model.eval() # Set model to evaluation mode
@@ -32,75 +40,81 @@ model.to(device)
32
  print(f"Model loaded successfully on {device}.")
33
  print(f"Model's class labels: {model.config.id2label}")
34
 
35
- # Initialize a global buffer for frames that the webcam continuously captures
36
- # This buffer will hold the *latest* NUM_FRAMES.
37
- # We use a global variable to persist state across Gradio calls.
38
  captured_frames_buffer = deque(maxlen=NUM_FRAMES)
 
 
 
 
39
 
40
- # This flag will control the 5-minute wait (if still needed for testing)
41
- wait_duration_seconds = 300 # 5 minutes
42
 
43
- # --- Function to continuously capture frames (without immediate processing) ---
44
- def capture_frame_into_buffer(image_np_array):
45
- global captured_frames_buffer
 
46
 
47
  # Convert Gradio's numpy array (RGB) to PIL Image
48
  pil_image = Image.fromarray(image_np_array)
49
  captured_frames_buffer.append(pil_image)
50
 
51
- # Return a message showing how many frames are buffered
52
- return f"Frames buffered: {len(captured_frames_buffer)}/{NUM_FRAMES}"
53
 
 
 
54
 
55
- # --- Function to trigger prediction with the buffered frames ---
56
- def make_prediction_from_buffer():
57
- global captured_frames_buffer
 
 
 
58
 
59
- if len(captured_frames_buffer) < NUM_FRAMES:
60
- return "Not enough frames buffered yet. Please capture more frames."
 
61
 
62
- # Take a snapshot of the current frames in the buffer for prediction
63
- # Convert deque to a list for the processor
64
- frames_for_prediction = list(captured_frames_buffer)
65
 
66
- # --- Perform Inference ---
67
- print(f"Triggered inference on {len(frames_for_prediction)} frames...")
68
- processed_input = processor(images=frames_for_prediction, return_tensors="pt")
69
- pixel_values = processed_input.pixel_values.to(device)
70
 
71
- with torch.no_grad():
72
- outputs = model(pixel_values)
73
- logits = outputs.logits
74
 
75
- predicted_class_id = logits.argmax(-1).item()
76
- predicted_label = model.config.id2label[predicted_class_id]
77
- confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()
78
-
79
- prediction_text = f"Predicted: {predicted_label} ({confidence:.2f})"
80
- print(prediction_text) # Print to Space logs
81
-
82
- # Clear the buffer after prediction if you want to capture a *new* set of frames for the next click
83
- # captured_frames_buffer.clear()
84
- # If you *don't* clear, the next click will re-predict on the same last 16 frames.
85
 
86
- # --- Introduce the artificial 5-minute wait (if still desired) ---
87
- # This will pause the *return* from this function, effectively blocking the UI update
88
- # If you remove this, the prediction will show immediately.
89
- # print(f"Initiating {wait_duration_seconds} second wait...")
90
- # time.sleep(wait_duration_seconds)
91
- # print("Wait finished.")
92
-
93
- return prediction_text
94
 
 
 
 
 
 
 
 
 
95
 
96
  # --- Gradio Interface ---
97
  with gr.Blocks() as demo:
98
  gr.Markdown(
99
  f"""
100
- # TimesFormer Crime Detection Live Demo (Manual Trigger)
101
  This demo uses a finetuned TimesFormer model ({HF_MODEL_REPO_ID}) to predict crime actions from a live webcam feed.
102
- It continuously buffers frames, but **only makes a prediction when you click the 'Predict' button**.
103
- The model requires **{NUM_FRAMES} frames** for a prediction.
104
  Please allow webcam access.
105
  """
106
  )
@@ -111,22 +125,29 @@ with gr.Blocks() as demo:
111
  streaming=True,
112
  label="Live Webcam Feed"
113
  )
114
- # This textbox will show the buffering status dynamically
115
- buffer_status = gr.Textbox(label="Frame Buffer Status", value=f"Frames buffered: 0/{NUM_FRAMES}")
116
 
117
- # Button to trigger prediction
118
- predict_button = gr.Button("Predict Latest Frames")
119
 
120
  with gr.Column():
121
- prediction_output = gr.Textbox(label="Prediction Result", value="Click 'Predict Latest Frames' to start.")
122
 
123
  # Define actions
124
- # This continuously updates the buffer_status as frames come in
125
- webcam_input.stream(capture_frame_into_buffer, inputs=[webcam_input], outputs=[buffer_status])
 
 
 
 
126
 
127
- # This triggers the prediction when the button is clicked
128
- predict_button.click(make_prediction_from_buffer, inputs=[], outputs=[prediction_output])
129
-
 
 
 
130
 
131
  if __name__ == "__main__":
132
  demo.launch()
 
12
  HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"
13
 
14
  # These must match the values used during your training
15
# --- Clip / frame configuration (must match the values used during training) ---
NUM_FRAMES = 8  # frames per clip; matches this model's original training setup
TARGET_IMAGE_HEIGHT = 224
TARGET_IMAGE_WIDTH = 224

# --- Prediction timing ---
# Seconds of webcam footage to buffer before a prediction is attempted.
RECORDING_DURATION_SECONDS = 3.0
# Minimum seconds between consecutive predictions. A very large value (e.g.
# 9999) effectively limits the app to one prediction per reset; 1.0 lets it
# trigger on every completed recording cycle.
INFERENCE_INTERVAL_SECONDS = 1.0
26
+
27
+
28
  # --- Load Model and Processor ---
29
  print(f"Loading model and image processor from {HF_MODEL_REPO_ID}...")
30
  try:
 
32
  model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID)
33
  except Exception as e:
34
  print(f"Error loading model from Hugging Face Hub: {e}")
 
35
  exit()
36
 
37
  model.eval() # Set model to evaluation mode
 
40
  print(f"Model loaded successfully on {device}.")
41
  print(f"Model's class labels: {model.config.id2label}")
42
 
43
# --- Global state shared across Gradio stream callbacks ---
# Rolling buffer of the most recent NUM_FRAMES webcam frames (PIL images);
# deque(maxlen=...) silently drops the oldest frame once full.
captured_frames_buffer = deque(maxlen=NUM_FRAMES)
# Wall-clock time at which the current recording cycle began (None = no
# cycle in progress; the next streamed frame starts one).
recording_start_time = None
# Wall-clock time of the most recent prediction, used to rate-limit inference.
last_prediction_time = time.time()
48
+
49
+ # --- Functions for Gradio Interface ---
50
 
51
def process_frame_and_predict(image_np_array):
    """Gradio stream callback: buffer one webcam frame, predict when a clip is ready.

    Appends the incoming RGB frame (numpy array) to the rolling clip buffer.
    Once RECORDING_DURATION_SECONDS have elapsed and NUM_FRAMES frames are
    buffered — and at least INFERENCE_INTERVAL_SECONDS have passed since the
    last prediction — runs the TimesFormer model on the buffered clip and
    resets the cycle.

    Returns:
        (status_text, prediction_text) for the two output textboxes.
    """
    global captured_frames_buffer, recording_start_time, last_prediction_time

    # First frame after a reset/prediction: begin a fresh recording cycle.
    if recording_start_time is None:
        recording_start_time = time.time()
        captured_frames_buffer.clear()

    # Gradio delivers frames as RGB numpy arrays; the processor wants PIL images.
    captured_frames_buffer.append(Image.fromarray(image_np_array))

    now = time.time()
    elapsed = now - recording_start_time
    status = (
        f"Recording: {elapsed:.1f}/{RECORDING_DURATION_SECONDS}s | "
        f"Frames: {len(captured_frames_buffer)}/{NUM_FRAMES}"
    )

    # Guard: keep recording until both the time and frame-count gates pass.
    clip_ready = (
        elapsed >= RECORDING_DURATION_SECONDS
        and len(captured_frames_buffer) >= NUM_FRAMES
    )
    if not clip_ready:
        return status, "Recording..."

    # Guard: rate-limit predictions to one per INFERENCE_INTERVAL_SECONDS.
    if (now - last_prediction_time) < INFERENCE_INTERVAL_SECONDS:
        return status, "Prediction done. Waiting for next interval..."

    # --- Perform inference on a snapshot of the buffered clip ---
    print(f"Triggered inference on {len(captured_frames_buffer)} frames after {RECORDING_DURATION_SECONDS}s recording...")
    clip = list(captured_frames_buffer)

    # The image processor handles resizing to TARGET_IMAGE_HEIGHT x TARGET_IMAGE_WIDTH.
    encoded = processor(images=clip, return_tensors="pt")
    pixel_values = encoded.pixel_values.to(device)

    with torch.no_grad():
        logits = model(pixel_values).logits

    class_id = logits.argmax(-1).item()
    label = model.config.id2label[class_id]
    confidence = torch.nn.functional.softmax(logits, dim=-1)[0][class_id].item()

    result = f"Predicted: {label} ({confidence:.2f})"
    print(result)  # mirror the result into the Space logs

    # Close out this cycle so the next streamed frame starts a new recording.
    last_prediction_time = now
    recording_start_time = None
    captured_frames_buffer.clear()

    return status, result
 
 
 
 
 
 
 
100
 
101
def reset_app_state():
    """Reset the shared recording/prediction state for a brand-new cycle.

    Empties the frame buffer, clears the recording start marker, and restamps
    the prediction rate-limit clock.

    Returns:
        (status_text, prediction_text) placeholder strings for the UI.
    """
    global captured_frames_buffer, recording_start_time, last_prediction_time

    captured_frames_buffer.clear()
    recording_start_time = None
    # Restart the inference-interval clock from "now".
    last_prediction_time = time.time()

    print("App state reset.")
    return "Ready to record...", "Ready for new prediction."
109
 
110
  # --- Gradio Interface ---
111
  with gr.Blocks() as demo:
112
  gr.Markdown(
113
  f"""
114
+ # TimesFormer Crime Detection Live Demo (Auto-Triggered Clip Prediction)
115
  This demo uses a finetuned TimesFormer model ({HF_MODEL_REPO_ID}) to predict crime actions from a live webcam feed.
116
+ It records **{RECORDING_DURATION_SECONDS} seconds** of video, then automatically triggers a prediction.
117
+ The model processes **{NUM_FRAMES} frames** per prediction.
118
  Please allow webcam access.
119
  """
120
  )
 
125
  streaming=True,
126
  label="Live Webcam Feed"
127
  )
128
+ # Textboxes for status and prediction
129
+ status_output = gr.Textbox(label="Status", value="Ready to record...")
130
 
131
+ # Reset Button
132
+ reset_button = gr.Button("Reset / Start New Recording Cycle")
133
 
134
  with gr.Column():
135
+ prediction_output = gr.Textbox(label="Prediction Result", value="Recording will start automatically.")
136
 
137
  # Define actions
138
+ # This continuously processes frames from the webcam
139
+ webcam_input.stream(
140
+ process_frame_and_predict,
141
+ inputs=[webcam_input],
142
+ outputs=[status_output, prediction_output] # Now outputs both status and prediction
143
+ )
144
 
145
+ # This triggers the reset function when the button is clicked
146
+ reset_button.click(
147
+ reset_app_state,
148
+ inputs=[],
149
+ outputs=[status_output, prediction_output] # Updates both output textboxes
150
+ )
151
 
152
  if __name__ == "__main__":
153
  demo.launch()