import gradio as gr
import torch
from transformers import AutoImageProcessor, TimesformerForVideoClassification
import cv2
from PIL import Image
import numpy as np
import time
from collections import deque
import base64
import io

# Model and streaming configuration
HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"
MODEL_INPUT_NUM_FRAMES = 8
TARGET_IMAGE_HEIGHT = 224
TARGET_IMAGE_WIDTH = 224
RAW_RECORDING_DURATION_SECONDS = 10.0
FRAMES_TO_SAMPLE_PER_CLIP = 20
DELAY_BETWEEN_PREDICTIONS_SECONDS = 120.0

print(f"Loading model and processor from {HF_MODEL_REPO_ID}...")
try:
    processor = AutoImageProcessor.from_pretrained(HF_MODEL_REPO_ID)
    model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID)
except Exception as e:
    print(f"Error loading model: {e}")
    exit()

model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Model loaded on {device}.")

# Shared state for the webcam streaming cycle
raw_frames_buffer = deque()
current_clip_start_time = time.time()
last_prediction_completion_time = time.time()
app_state = "recording"

def sample_frames(frames_list, target_count):
    """Uniformly sample up to target_count frames from frames_list."""
    if not frames_list:
        return []
    if len(frames_list) <= target_count:
        return frames_list
    # Pick evenly spaced indices across the clip
    indices = np.linspace(0, len(frames_list) - 1, target_count, dtype=int)
    sampled = [frames_list[int(i)] for i in indices]
    return sampled
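
# Example: for a 10-frame list sampled down to 4 frames,
# np.linspace(0, 9, 4, dtype=int) yields indices [0, 3, 6, 9], so the first
# and last frames of a clip are always retained:
#   sample_frames(list(range(10)), 4)  # -> [0, 3, 6, 9]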

def live_predict_stream(image_np_array):
    """Stream callback implementing a three-state cycle: record raw frames,
    run one prediction, then wait before the next recording starts."""
    global raw_frames_buffer, current_clip_start_time, last_prediction_completion_time, app_state
    current_time = time.time()
    pil_image = Image.fromarray(image_np_array)
    status_message = ""
    prediction_result = ""

    if app_state == "recording":
        raw_frames_buffer.append(pil_image)
        elapsed_recording_time = current_time - current_clip_start_time
        status_message = f"Recording: {elapsed_recording_time:.1f}/{RAW_RECORDING_DURATION_SECONDS}s. Raw frames: {len(raw_frames_buffer)}"
        prediction_result = "Buffering..."
        if elapsed_recording_time >= RAW_RECORDING_DURATION_SECONDS:
            app_state = "predicting"
            status_message = "Preparing to predict..."
            prediction_result = "Processing..."
            print("DEBUG: Transitioning to 'predicting' state.")

    elif app_state == "predicting":
        if raw_frames_buffer:
            print("DEBUG: Starting prediction.")
            try:
                sampled_raw_frames = sample_frames(list(raw_frames_buffer), FRAMES_TO_SAMPLE_PER_CLIP)
                frames_for_model = sample_frames(sampled_raw_frames, MODEL_INPUT_NUM_FRAMES)
                if len(frames_for_model) < MODEL_INPUT_NUM_FRAMES:
                    prediction_result = "Error: Not enough frames for model."
                    status_message = "Error during frame sampling."
                    print(f"ERROR: Insufficient frames for model input: {len(frames_for_model)}/{MODEL_INPUT_NUM_FRAMES}")
                    app_state = "recording"  # Reset to recording state
                    raw_frames_buffer.clear()
                    current_clip_start_time = time.time()
                    last_prediction_completion_time = time.time()
                    return status_message, prediction_result

                processed_input = processor(images=frames_for_model, return_tensors="pt")
                pixel_values = processed_input.pixel_values.to(device)
                with torch.no_grad():
                    outputs = model(pixel_values)
                    logits = outputs.logits
                predicted_class_id = logits.argmax(-1).item()
                predicted_label = model.config.id2label.get(predicted_class_id, "Unknown")
                confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()
                prediction_result = f"Predicted: {predicted_label} (Confidence: {confidence:.2f})"
                status_message = "Prediction complete."
                print(f"DEBUG: Prediction Result: {prediction_result}")

                raw_frames_buffer.clear()
                last_prediction_completion_time = current_time
                app_state = "processing_delay"
                print("DEBUG: Transitioning to 'processing_delay' state.")
            except Exception as e:
                prediction_result = f"Error during prediction: {e}"
                status_message = "Prediction error."
                print(f"ERROR during prediction: {e}")
                app_state = "processing_delay"  # Move to delay to avoid continuous errors
        else:
            status_message = "Waiting for frames..."
            prediction_result = "..."

    elif app_state == "processing_delay":
        elapsed_delay = current_time - last_prediction_completion_time
        status_message = f"Delaying next prediction: {int(elapsed_delay)}/{int(DELAY_BETWEEN_PREDICTIONS_SECONDS)}s"
        if elapsed_delay >= DELAY_BETWEEN_PREDICTIONS_SECONDS:
            app_state = "recording"
            current_clip_start_time = current_time
            status_message = "Starting new recording..."
            prediction_result = "Ready..."
            print("DEBUG: Transitioning back to 'recording' state.")

    return status_message, prediction_result

def reset_app_state_manual():
    global raw_frames_buffer, current_clip_start_time, last_prediction_completion_time, app_state
    raw_frames_buffer.clear()
    current_clip_start_time = time.time()
    last_prediction_completion_time = time.time()
    app_state = "recording"
    print("DEBUG: Manual reset triggered.")
    return "Ready to record...", "Ready for new prediction."

with gr.Blocks() as demo:
    gr.Markdown(
        """
        # TimesFormer Crime Detection - Hugging Face Space Host
        This Space hosts the `owinymarvin/timesformer-crime-detection` model.
        Live webcam demo with recording and prediction phases.
        """
    )
    with gr.Tab("Live Webcam Demo"):
        gr.Markdown(
            f"""
            Continuously captures the live webcam feed for **{RAW_RECORDING_DURATION_SECONDS} seconds**,
            then makes a prediction, followed by a **{DELAY_BETWEEN_PREDICTIONS_SECONDS/60:.0f} minute delay**
            before the next recording cycle begins.
            """
        )
        with gr.Row():
            with gr.Column():
                webcam_input = gr.Image(
                    sources=["webcam"],
                    streaming=True,
                    label="Live Webcam Feed"
                )
                status_output = gr.Textbox(label="Current Status", value="Initializing...")
                reset_button = gr.Button("Reset / Start New Cycle")
            with gr.Column():
                prediction_output = gr.Textbox(label="Prediction Result", value="Waiting...")

        webcam_input.stream(
            live_predict_stream,
            inputs=[webcam_input],
            outputs=[status_output, prediction_output]
        )
        reset_button.click(
            reset_app_state_manual,
            inputs=[],
            outputs=[status_output, prediction_output]
        )
with gr.Tab("API Endpoint for External Clients"):
gr.Markdown(
"""
Use this API endpoint to send base64-encoded frames for prediction.
"""
)
# Re-adding a slightly more representative API interface
# Gradio's automatic API documentation will use this to show inputs/outputs
gr.Interface(
fn=lambda frames_list: f"Received {len(frames_list)} frames. This is a dummy response. Integrate predict_from_frames_api here.",
inputs=gr.Json(label="List of Base64-encoded image strings"),
outputs=gr.Textbox(label="API Response"),
live=False,
allow_flagging="never" # For API endpoints, flagging is usually not desired
)
# Note: The actual `predict_from_frames_api` function is defined above,
# but for a clean API tab, we can use a dummy interface here that Gradio will
# use to generate the interactive API documentation. The actual API call
# from your local script directly targets the /run/predict_from_frames_api endpoint.

if __name__ == "__main__":
    demo.launch()
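
# Example external client (a sketch: the Space URL placeholder must be replaced
# with your actual Space, and the /run/... route assumes Gradio's named-endpoint
# HTTP API referenced in the comment above):
#
#   import requests
#   frames_b64 = [...]  # list of base64-encoded image strings
#   resp = requests.post(
#       "https://<username>-<space-name>.hf.space/run/predict_from_frames_api",
#       json={"data": [frames_b64]},
#   )
#   print(resp.json())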