import gradio as gr
import torch
from transformers import AutoImageProcessor, TimesformerForVideoClassification
import cv2
from PIL import Image
import numpy as np
import time
from collections import deque
# cv2, base64, and io are unused by the live demo itself; they are kept for
# the base64-frames API path described in the second tab below.
import base64
import io
HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"
MODEL_INPUT_NUM_FRAMES = 8
TARGET_IMAGE_HEIGHT = 224
TARGET_IMAGE_WIDTH = 224
RAW_RECORDING_DURATION_SECONDS = 10.0
FRAMES_TO_SAMPLE_PER_CLIP = 20
DELAY_BETWEEN_PREDICTIONS_SECONDS = 120.0
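# With these defaults, each cycle is ~10 s of webcam capture, one inference
# pass, then a 120 s cool-down, i.e. roughly one prediction every two minutes.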
print(f"Loading model and processor from {HF_MODEL_REPO_ID}...") | |
try: | |
processor = AutoImageProcessor.from_pretrained(HF_MODEL_REPO_ID) | |
model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID) | |
except Exception as e: | |
print(f"Error loading model: {e}") | |
exit() | |
model.eval() | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
model.to(device) | |
print(f"Model loaded on {device}.") | |
raw_frames_buffer = deque()
current_clip_start_time = time.time()
last_prediction_completion_time = time.time()
# State machine: "recording" -> "predicting" -> "processing_delay" -> repeat.
app_state = "recording"
def sample_frames(frames_list, target_count):
    if not frames_list:
        return []
    if len(frames_list) <= target_count:
        return frames_list
    # Evenly spaced indices across the clip, endpoints included.
    indices = np.linspace(0, len(frames_list) - 1, target_count, dtype=int)
    return [frames_list[int(i)] for i in indices]
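# Quick sanity check of the even-spacing behaviour (runnable in a REPL):
#   >>> sample_frames(list(range(20)), 8)
#   [0, 2, 5, 8, 10, 13, 16, 19]
# np.linspace casts to int by truncation, so the first and last frames of a
# clip are always retained.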
def live_predict_stream(image_np_array):
    global raw_frames_buffer, current_clip_start_time, last_prediction_completion_time, app_state
    # Gradio can emit empty frames while the webcam warms up.
    if image_np_array is None:
        return "Waiting for webcam...", "..."
    current_time = time.time()
    pil_image = Image.fromarray(image_np_array)
    status_message = ""
    prediction_result = ""

    if app_state == "recording":
        raw_frames_buffer.append(pil_image)
        elapsed_recording_time = current_time - current_clip_start_time
        status_message = (
            f"Recording: {elapsed_recording_time:.1f}/{RAW_RECORDING_DURATION_SECONDS}s. "
            f"Raw frames: {len(raw_frames_buffer)}"
        )
        prediction_result = "Buffering..."
        if elapsed_recording_time >= RAW_RECORDING_DURATION_SECONDS:
            app_state = "predicting"
            status_message = "Preparing to predict..."
            prediction_result = "Processing..."
            print("DEBUG: Transitioning to 'predicting' state.")
    elif app_state == "predicting":
        if raw_frames_buffer:
            print("DEBUG: Starting prediction.")
            try:
                # Two-stage downsampling: raw buffer -> 20 frames -> 8 model frames.
                sampled_raw_frames = sample_frames(list(raw_frames_buffer), FRAMES_TO_SAMPLE_PER_CLIP)
                frames_for_model = sample_frames(sampled_raw_frames, MODEL_INPUT_NUM_FRAMES)
                if len(frames_for_model) < MODEL_INPUT_NUM_FRAMES:
                    prediction_result = "Error: Not enough frames for model."
                    status_message = "Error during frame sampling."
                    print(f"ERROR: Insufficient frames for model input: "
                          f"{len(frames_for_model)}/{MODEL_INPUT_NUM_FRAMES}")
                    app_state = "recording"  # Reset to recording state.
                    raw_frames_buffer.clear()
                    current_clip_start_time = time.time()
                    last_prediction_completion_time = time.time()
                    return status_message, prediction_result
                processed_input = processor(images=frames_for_model, return_tensors="pt")
                # pixel_values shape: (1, num_frames, 3, 224, 224).
                pixel_values = processed_input.pixel_values.to(device)
                with torch.no_grad():
                    outputs = model(pixel_values)
                    logits = outputs.logits
                predicted_class_id = logits.argmax(-1).item()
                predicted_label = model.config.id2label.get(predicted_class_id, "Unknown")
                confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()
                prediction_result = f"Predicted: {predicted_label} (Confidence: {confidence:.2f})"
                status_message = "Prediction complete."
                print(f"DEBUG: Prediction Result: {prediction_result}")
                raw_frames_buffer.clear()
                last_prediction_completion_time = current_time
                app_state = "processing_delay"
                print("DEBUG: Transitioning to 'processing_delay' state.")
            except Exception as e:
                prediction_result = f"Error during prediction: {e}"
                status_message = "Prediction error."
                print(f"ERROR during prediction: {e}")
                app_state = "processing_delay"  # Back off instead of erroring on every frame.
        else:
            status_message = "Waiting for frames..."
            prediction_result = "..."
    elif app_state == "processing_delay":
        elapsed_delay = current_time - last_prediction_completion_time
        status_message = f"Delaying next prediction: {int(elapsed_delay)}/{int(DELAY_BETWEEN_PREDICTIONS_SECONDS)}s"
        if elapsed_delay >= DELAY_BETWEEN_PREDICTIONS_SECONDS:
            app_state = "recording"
            current_clip_start_time = current_time
            status_message = "Starting new recording..."
            prediction_result = "Ready..."
            print("DEBUG: Transitioning back to 'recording' state.")
    return status_message, prediction_result
def reset_app_state_manual():
    global raw_frames_buffer, current_clip_start_time, last_prediction_completion_time, app_state
    raw_frames_buffer.clear()
    current_clip_start_time = time.time()
    last_prediction_completion_time = time.time()
    app_state = "recording"
    print("DEBUG: Manual reset triggered.")
    return "Ready to record...", "Ready for new prediction."
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # TimesFormer Crime Detection - Hugging Face Space Host
        This Space hosts the `owinymarvin/timesformer-crime-detection` model.
        Live webcam demo with recording and prediction phases.
        """
    )
    with gr.Tab("Live Webcam Demo"):
        gr.Markdown(
            f"""
            Continuously captures the live webcam feed for **{RAW_RECORDING_DURATION_SECONDS:.0f} seconds**,
            then makes a prediction. There is a **{DELAY_BETWEEN_PREDICTIONS_SECONDS / 60:.0f}-minute delay** afterwards.
            """
        )
        with gr.Row():
            with gr.Column():
                webcam_input = gr.Image(
                    sources=["webcam"],
                    streaming=True,
                    label="Live Webcam Feed"
                )
                status_output = gr.Textbox(label="Current Status", value="Initializing...")
                reset_button = gr.Button("Reset / Start New Cycle")
            with gr.Column():
                prediction_output = gr.Textbox(label="Prediction Result", value="Waiting...")
        webcam_input.stream(
            live_predict_stream,
            inputs=[webcam_input],
            outputs=[status_output, prediction_output]
        )
        reset_button.click(
            reset_app_state_manual,
            inputs=[],
            outputs=[status_output, prediction_output]
        )
with gr.Tab("API Endpoint for External Clients"): | |
gr.Markdown( | |
""" | |
Use this API endpoint to send base64-encoded frames for prediction. | |
""" | |
) | |
# Re-adding a slightly more representative API interface | |
# Gradio's automatic API documentation will use this to show inputs/outputs | |
gr.Interface( | |
fn=lambda frames_list: f"Received {len(frames_list)} frames. This is a dummy response. Integrate predict_from_frames_api here.", | |
inputs=gr.Json(label="List of Base64-encoded image strings"), | |
outputs=gr.Textbox(label="API Response"), | |
live=False, | |
allow_flagging="never" # For API endpoints, flagging is usually not desired | |
) | |
# Note: The actual `predict_from_frames_api` function is defined above, | |
# but for a clean API tab, we can use a dummy interface here that Gradio will | |
# use to generate the interactive API documentation. The actual API call | |
# from your local script directly targets the /run/predict_from_frames_api endpoint. | |
if __name__ == "__main__":
    demo.launch()
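# -----------------------------------------------------------------------------
# Illustrative external client (a sketch, not executed by this Space). It
# assumes the `gradio_client` package and that the dummy JSON interface above
# is exposed at api_name="/predict"; both the Space name and api_name may need
# adjusting for a real deployment, so it is left commented out.
#
#   from gradio_client import Client
#   import base64
#
#   client = Client("owinymarvin/timesformer-crime-detection")
#   with open("frame_0.jpg", "rb") as f:
#       b64_frame = base64.b64encode(f.read()).decode("utf-8")
#   result = client.predict([b64_frame], api_name="/predict")
#   print(result)
# -----------------------------------------------------------------------------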