Spaces · Running
Commit 4873e8b · latest changes
Parent: 4784ef2

app.py CHANGED
Old version (removed lines marked "-"):

@@ -6,77 +6,49 @@ from PIL import Image

  import numpy as np
  import time
  from collections import deque

- # --- Configuration ---
- # Your Hugging Face model repository ID
  HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"
-
- # These must match the values used during your training.
- # IMPORTANT: Your model was trained on NUM_FRAMES = 8.
- # If you want to use 20 frames, this model will likely NOT perform well,
- # as that is a mismatch. If you truly need 20 frames, retrain the model with 20.
- # For now, keep it at 8 as per training; we can still capture 20 frames for sampling.
- MODEL_INPUT_NUM_FRAMES = 8  # the 'NUM_FRAMES' the model expects
  TARGET_IMAGE_HEIGHT = 224
  TARGET_IMAGE_WIDTH = 224

- RAW_RECORDING_DURATION_SECONDS = 10.0  # capture raw frames for this duration for each clip
- FRAMES_TO_SAMPLE_PER_CLIP = 20  # number of frames to sample from the raw 10 s clip
- # NOTE: the model will only use MODEL_INPUT_NUM_FRAMES (8) of these.
-
- # The delay *after* a prediction is made before the next prediction cycle starts.
- # Set to 120.0 seconds (2 minutes) for CPU testing. Change this for GPU.
- DELAY_BETWEEN_PREDICTIONS_SECONDS = 120.0  # CHANGED: delay between predictions
-
- # --- Load Model and Processor ---
- print(f"Loading model and image processor from {HF_MODEL_REPO_ID}...")
  try:
      processor = AutoImageProcessor.from_pretrained(HF_MODEL_REPO_ID)
      model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID)
  except Exception as e:
-     print(f"Error loading model: {e}")
      exit()

- model.eval()
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  model.to(device)
- print(f"Model loaded on {device}.")
- print(f"Model's class labels: {model.config.id2label}")

- raw_frames_buffer = deque()
- current_clip_start_time = time.time()  # time when the current 10-second clip started
- last_prediction_completion_time = time.time()  # time when the last prediction finished
-
- # State machine for the app's workflow.
- # States: "recording", "processing_delay", "predicting"
  app_state = "recording"

- # --- Helper function to sample frames ---
  def sample_frames(frames_list, target_count):
-     """
-     Samples target_count frames evenly from a list of frames.
-     If frames_list has fewer than target_count, it returns all frames.
-     """
      if not frames_list:
          return []
      if len(frames_list) <= target_count:
          return frames_list
      indices = np.linspace(0, len(frames_list) - 1, target_count, dtype=int)
-     sampled = [frames_list[int(i)] for i in indices]
      return sampled

- # --- Main processing function for the Gradio stream ---
  def live_predict_stream(image_np_array):
      global raw_frames_buffer, current_clip_start_time, last_prediction_completion_time, app_state

      current_time = time.time()
-     pil_image = Image.fromarray(image_np_array)

      status_message = ""
      prediction_result = ""
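Aside: the configuration comments above stress that the model was trained with NUM_FRAMES = 8, so whatever is captured, exactly eight frames must reach TimesFormer. A minimal shape check, as a sketch only (the zero-filled frames are placeholders, and it assumes the model repo is reachable):

    import numpy as np
    from PIL import Image
    from transformers import AutoImageProcessor

    # Hypothetical sanity check: eight 224x224 RGB frames in, one video tensor out.
    processor = AutoImageProcessor.from_pretrained("owinymarvin/timesformer-crime-detection")
    frames = [Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8)) for _ in range(8)]
    inputs = processor(images=frames, return_tensors="pt")
    # TimesFormer expects (batch, num_frames, channels, height, width), with num_frames == 8 here.
    print(inputs.pixel_values.shape)  # expected: torch.Size([1, 8, 3, 224, 224])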
@@ -84,136 +56,131 @@ def live_predict_stream(image_np_array):

      if app_state == "recording":
          raw_frames_buffer.append(pil_image)
          elapsed_recording_time = current_time - current_clip_start_time
-         … (removed lines 87–89 not shown)
-         prediction_result = "Buffering for next clip..."
-     else:
-         # Done recording; now move to the predicting state
          app_state = "predicting"
-         status_message = …
-         prediction_result = "Processing…
-         print(…
-         … (removed lines 97–118 not shown)
          raw_frames_buffer.clear()
-         … (removed lines 120–128 not shown)
-         outputs = model(pixel_values)
-         logits = outputs.logits
-
-         predicted_class_id = logits.argmax(-1).item()
-         predicted_label = model.config.id2label[predicted_class_id]
-         confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()
-
-         prediction_result = f"Predicted: {predicted_label} ({confidence:.2f})"
-         print(f"DEBUG: {prediction_result}")
-
-         # Clear the raw buffer, as this clip has been processed
-         raw_frames_buffer.clear()
-         last_prediction_completion_time = current_time  # mark the time the prediction finished
-         app_state = "processing_delay"  # move to the delay state
-         status_message = f"Prediction complete. Waiting for {DELAY_BETWEEN_PREDICTIONS_SECONDS}s delay."
      else:
-         … (removed lines 145–146 not shown)
-         prediction_result = "..."

      elif app_state == "processing_delay":
          elapsed_delay = current_time - last_prediction_completion_time
-         … (removed lines 151–152 not shown)
-         # Keep showing the last prediction result during the delay
-     else:
-         # Delay is over; reset for the next recording cycle
          app_state = "recording"
-         current_clip_start_time = current_time
-         status_message = "…
-         prediction_result = "…
-         print(…

      return status_message, prediction_result

  def reset_app_state_manual():
-     """Resets the global state variables and starts a new recording cycle immediately."""
      global raw_frames_buffer, current_clip_start_time, last_prediction_completion_time, app_state
      raw_frames_buffer.clear()
      current_clip_start_time = time.time()
      last_prediction_completion_time = time.time()
-     app_state = "recording"
-     print("…
-     return "Ready to record...", "Ready for new prediction…

- # --- Gradio Interface ---
  with gr.Blocks() as demo:
      gr.Markdown(
          f"""
-         # TimesFormer Crime Detection
-         This … (truncated in the viewer)
-         From this, it samples **{FRAMES_TO_SAMPLE_PER_CLIP} frames** (for context) and then extracts **{MODEL_INPUT_NUM_FRAMES} frames**
-         for the TimesFormer model to make a prediction.
-         After each prediction, there is a **{DELAY_BETWEEN_PREDICTIONS_SECONDS/60:.0f} minute delay** before the next prediction cycle begins.
-         Please allow webcam access.
          """
      )
-     … (removed lines 187–216, the old interface layout, not shown)

  if __name__ == "__main__":
      demo.launch()
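Both versions share the same two-stage sampling strategy: np.linspace picks evenly spaced indices, so 20 frames are first drawn from the ~10 s buffer and 8 of those are then handed to the model. A self-contained sketch of that helper, with plain integers standing in for frames:

    import numpy as np

    def sample_frames(frames_list, target_count):
        # Evenly spaced indices over the whole clip; short clips are returned as-is.
        if not frames_list:
            return []
        if len(frames_list) <= target_count:
            return frames_list
        indices = np.linspace(0, len(frames_list) - 1, target_count, dtype=int)
        return [frames_list[int(i)] for i in indices]

    raw = list(range(100))                   # stand-in for ~100 captured frames
    context = sample_frames(raw, 20)         # 20 evenly spaced frames from the clip
    model_input = sample_frames(context, 8)  # 8 of those go to TimesFormer
    print(model_input)                       # [0, 10, 26, 41, 52, 67, 83, 99]

Even spacing gives the model temporal coverage of the whole clip instead of just its first eight frames.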
New version (added lines marked "+"):

  import numpy as np
  import time
  from collections import deque
+ import base64
+ import io

  HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"
+ MODEL_INPUT_NUM_FRAMES = 8
  TARGET_IMAGE_HEIGHT = 224
  TARGET_IMAGE_WIDTH = 224
+ RAW_RECORDING_DURATION_SECONDS = 10.0
+ FRAMES_TO_SAMPLE_PER_CLIP = 20
+ DELAY_BETWEEN_PREDICTIONS_SECONDS = 120.0

+ print(f"Loading model and processor from {HF_MODEL_REPO_ID}...")
  try:
      processor = AutoImageProcessor.from_pretrained(HF_MODEL_REPO_ID)
      model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID)
  except Exception as e:
+     print(f"Error loading model: {e}")
      exit()

+ model.eval()
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  model.to(device)
+ print(f"Model loaded on {device}.")

+ raw_frames_buffer = deque()
+ current_clip_start_time = time.time()
+ last_prediction_completion_time = time.time()
  app_state = "recording"

  def sample_frames(frames_list, target_count):
      if not frames_list:
          return []
      if len(frames_list) <= target_count:
          return frames_list
      indices = np.linspace(0, len(frames_list) - 1, target_count, dtype=int)
+     sampled = [frames_list[int(i)] for i in indices]  # corrected sampling (list indexing, not a call)
      return sampled

  def live_predict_stream(image_np_array):
      global raw_frames_buffer, current_clip_start_time, last_prediction_completion_time, app_state

      current_time = time.time()
+     pil_image = Image.fromarray(image_np_array)

      status_message = ""
      prediction_result = ""

      if app_state == "recording":
          raw_frames_buffer.append(pil_image)
          elapsed_recording_time = current_time - current_clip_start_time
+         status_message = f"Recording: {elapsed_recording_time:.1f}/{RAW_RECORDING_DURATION_SECONDS}s. Raw frames: {len(raw_frames_buffer)}"
+         prediction_result = "Buffering..."
+         if elapsed_recording_time >= RAW_RECORDING_DURATION_SECONDS:
              app_state = "predicting"
+             status_message = "Preparing to predict..."
+             prediction_result = "Processing..."
+             print("DEBUG: Transitioning to 'predicting' state.")

+     elif app_state == "predicting":
+         if raw_frames_buffer:
+             print("DEBUG: Starting prediction.")
+             try:
+                 sampled_raw_frames = sample_frames(list(raw_frames_buffer), FRAMES_TO_SAMPLE_PER_CLIP)
+                 frames_for_model = sample_frames(sampled_raw_frames, MODEL_INPUT_NUM_FRAMES)

+                 if len(frames_for_model) < MODEL_INPUT_NUM_FRAMES:
+                     prediction_result = "Error: Not enough frames for model."
+                     status_message = "Error during frame sampling."
+                     app_state = "recording"
+                     raw_frames_buffer.clear()
+                     current_clip_start_time = time.time()
+                     last_prediction_completion_time = time.time()
+                     return status_message, prediction_result

+                 processed_input = processor(images=frames_for_model, return_tensors="pt")
+                 pixel_values = processed_input.pixel_values.to(device)

+                 with torch.no_grad():
+                     outputs = model(pixel_values)
+                     logits = outputs.logits

+                 predicted_class_id = logits.argmax(-1).item()
+                 predicted_label = model.config.id2label.get(predicted_class_id, "Unknown")  # handle a potentially missing label
+                 confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()

+                 prediction_result = f"Predicted: {predicted_label} (Confidence: {confidence:.2f})"
+                 status_message = "Prediction complete."
+                 print(f"DEBUG: Prediction Result: {prediction_result}")

                  raw_frames_buffer.clear()
+                 last_prediction_completion_time = current_time
+                 app_state = "processing_delay"
+                 print("DEBUG: Transitioning to 'processing_delay' state.")

+             except Exception as e:
+                 prediction_result = f"Error during prediction: {e}"
+                 status_message = "Prediction error."
+                 print(f"ERROR during prediction: {e}")
+                 app_state = "processing_delay"  # move to the delay state to avoid continuous errors
          else:
+             status_message = "Waiting for frames..."
+             prediction_result = "..."

      elif app_state == "processing_delay":
          elapsed_delay = current_time - last_prediction_completion_time
+         status_message = f"Delaying next prediction: {int(elapsed_delay)}/{int(DELAY_BETWEEN_PREDICTIONS_SECONDS)}s"
+         if elapsed_delay >= DELAY_BETWEEN_PREDICTIONS_SECONDS:
              app_state = "recording"
+             current_clip_start_time = current_time
+             status_message = "Starting new recording..."
+             prediction_result = "Ready..."
+             print("DEBUG: Transitioning back to 'recording' state.")

      return status_message, prediction_result

  def reset_app_state_manual():
      global raw_frames_buffer, current_clip_start_time, last_prediction_completion_time, app_state
      raw_frames_buffer.clear()
      current_clip_start_time = time.time()
      last_prediction_completion_time = time.time()
+     app_state = "recording"
+     print("DEBUG: Manual reset triggered.")
+     return "Ready to record...", "Ready for new prediction."

  with gr.Blocks() as demo:
      gr.Markdown(
          f"""
+         # TimesFormer Crime Detection - Hugging Face Space Host
+         This Space hosts the `owinymarvin/timesformer-crime-detection` model.
+         Live webcam demo with recording and prediction phases.
          """
      )

+     with gr.Tab("Live Webcam Demo"):
+         gr.Markdown(
+             f"""
+             Continuously captures the live webcam feed for **{RAW_RECORDING_DURATION_SECONDS} seconds**,
+             then makes a prediction. There is a **{DELAY_BETWEEN_PREDICTIONS_SECONDS/60:.0f} minute delay** afterwards.
+             """
+         )
+         with gr.Row():
+             with gr.Column():
+                 webcam_input = gr.Image(
+                     sources=["webcam"],
+                     streaming=True,
+                     label="Live Webcam Feed"
+                 )
+                 status_output = gr.Textbox(label="Current Status", value="Initializing...")
+                 reset_button = gr.Button("Reset / Start New Cycle")
+             with gr.Column():
+                 prediction_output = gr.Textbox(label="Prediction Result", value="Waiting...")

+         webcam_input.stream(
+             live_predict_stream,
+             inputs=[webcam_input],
+             outputs=[status_output, prediction_output]
+         )
+         reset_button.click(
+             reset_app_state_manual,
+             inputs=[],
+             outputs=[status_output, prediction_output]
+         )

+     with gr.Tab("API Endpoint for External Clients"):
+         gr.Markdown(
+             """
+             Use this API endpoint to send base64-encoded frames for prediction.
+             """
+         )
+         gr.Interface(
+             fn=lambda x: "API endpoint is active",
+             inputs=gr.Textbox(label="Input (Base64 JSON)"),
+             outputs=gr.Textbox(label="Status"),
+             title="API Status (Details in app.py)"  # minimal UI for the API tab
+         )

  if __name__ == "__main__":
      demo.launch()
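The new base64/io imports and the "API Endpoint for External Clients" tab point at external callers posting base64-encoded frames, but the tab in this commit is only a status stub. A hedged client sketch, assuming gradio_client is installed; the Space ID, the api_name, and the JSON payload shape are guesses, not part of this commit:

    import base64
    import io
    import json

    from PIL import Image
    from gradio_client import Client

    # Hypothetical client for the stub API tab above. The Space ID is assumed
    # to match the model repo; the payload format is an assumption.
    client = Client("owinymarvin/timesformer-crime-detection")

    def encode_frame(image: Image.Image) -> str:
        # JPEG-compress the frame and base64-encode it for transport.
        buf = io.BytesIO()
        image.save(buf, format="JPEG")
        return base64.b64encode(buf.getvalue()).decode("utf-8")

    frame = Image.new("RGB", (224, 224))
    payload = json.dumps({"frames": [encode_frame(frame)]})
    print(client.predict(payload, api_name="/predict"))  # the stub returns its status text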