# Hugging Face Space status: Running
import gradio as gr | |
import torch | |
from transformers import AutoImageProcessor, TimesformerForVideoClassification | |
import cv2 | |
from PIL import Image | |
import numpy as np | |
import time | |
from collections import deque | |
# --- Configuration ---
# Hugging Face Hub repository that hosts the fine-tuned checkpoint.
HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"

# These must match the values used during training.
NUM_FRAMES = 16  # Number of frames required before a prediction is attempted.
# NOTE(review): the two target sizes are not referenced in the code visible
# here; presumably kept for reference alongside the training setup.
TARGET_IMAGE_HEIGHT = 224
TARGET_IMAGE_WIDTH = 224

# --- Load Model and Processor ---
print(f"Loading model and image processor from {HF_MODEL_REPO_ID}...")
try:
    processor = AutoImageProcessor.from_pretrained(HF_MODEL_REPO_ID)
    model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID)
except Exception as e:
    print(f"Error loading model from Hugging Face Hub: {e}")
    # Abort with a non-zero status so the failure is visible to whatever
    # launched the process. The previous bare `exit()` is the interactive
    # helper injected by `site` and terminates with status 0 (success).
    raise SystemExit(1) from e

model.eval()  # Inference only: switch the model to evaluation mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Model loaded successfully on {device}.")
print(f"Model's class labels: {model.config.id2label}")

# Module-level buffer holding the *latest* NUM_FRAMES webcam frames.
# It lives at module scope so its contents persist across Gradio callbacks;
# `maxlen` makes the deque drop the oldest frame once it is full.
captured_frames_buffer = deque(maxlen=NUM_FRAMES)

# Length of the optional artificial wait (5 minutes). Only referenced by the
# disabled sleep described in make_prediction_from_buffer.
wait_duration_seconds = 300
# --- Function to continuously capture frames (without immediate processing) ---
def capture_frame_into_buffer(image_np_array):
    """Append one streamed webcam frame to the shared frame buffer.

    Args:
        image_np_array: Frame delivered by the streaming image component
            (expected to be an RGB numpy array), or ``None`` when the
            component has no frame to deliver.

    Returns:
        A status string reporting how many frames are currently buffered
        out of the ``NUM_FRAMES`` needed for a prediction.
    """
    # A streaming input can fire with no image (e.g. before the webcam has
    # produced a frame or after the feed is cleared). Skip those events
    # instead of letting Image.fromarray() raise on None.
    if image_np_array is not None:
        captured_frames_buffer.append(Image.fromarray(image_np_array))

    return f"Frames buffered: {len(captured_frames_buffer)}/{NUM_FRAMES}"
# --- Function to trigger prediction with the buffered frames ---
def make_prediction_from_buffer():
    """Classify the frames currently held in the shared buffer.

    Returns:
        A "Predicted: <label> (<confidence>)" string for the top class, or a
        notice asking for more frames when fewer than ``NUM_FRAMES`` have
        been buffered so far.
    """
    buffer_is_full = len(captured_frames_buffer) >= NUM_FRAMES
    if not buffer_is_full:
        return "Not enough frames buffered yet. Please capture more frames."

    # Work on a list snapshot of the deque; the buffer itself is left as is,
    # so a second click without new frames re-predicts on the same frames.
    frames_for_prediction = list(captured_frames_buffer)
    print(f"Triggered inference on {len(frames_for_prediction)} frames...")

    # Preprocess the frames and run a gradient-free forward pass.
    batch = processor(images=frames_for_prediction, return_tensors="pt")
    with torch.no_grad():
        logits = model(batch.pixel_values.to(device)).logits

    # Top class from the raw logits; its softmax probability is the confidence.
    predicted_class_id = logits.argmax(-1).item()
    class_probabilities = torch.nn.functional.softmax(logits, dim=-1)
    confidence = class_probabilities[0][predicted_class_id].item()
    predicted_label = model.config.id2label[predicted_class_id]

    prediction_text = f"Predicted: {predicted_label} ({confidence:.2f})"
    print(prediction_text)  # Also emitted to the process logs.

    # Optional behaviours, intentionally left disabled:
    #   * captured_frames_buffer.clear() here would force a fresh set of
    #     frames to be captured before the next prediction.
    #   * time.sleep(wait_duration_seconds) here would delay the returned
    #     result (and therefore the UI update) by the artificial wait.
    return prediction_text
# --- Gradio Interface ---
# Layout: the left column shows the live webcam feed, the buffer status and
# the trigger button; the right column shows the latest prediction.
with gr.Blocks() as demo:
    gr.Markdown(
        f"""
# TimesFormer Crime Detection Live Demo (Manual Trigger)
This demo uses a finetuned TimesFormer model ({HF_MODEL_REPO_ID}) to predict crime actions from a live webcam feed.
It continuously buffers frames, but **only makes a prediction when you click the 'Predict' button**.
The model requires **{NUM_FRAMES} frames** for a prediction.
Please allow webcam access.
"""
    )

    with gr.Row():
        with gr.Column():
            webcam_input = gr.Image(
                label="Live Webcam Feed",
                sources=["webcam"],
                streaming=True,
            )
            # Updated on every streamed frame with the current buffer fill.
            buffer_status = gr.Textbox(
                label="Frame Buffer Status",
                value=f"Frames buffered: 0/{NUM_FRAMES}",
            )
            # Manual trigger for running inference.
            predict_button = gr.Button("Predict Latest Frames")

        with gr.Column():
            prediction_output = gr.Textbox(
                label="Prediction Result",
                value="Click 'Predict Latest Frames' to start.",
            )

    # Event wiring.
    # Each streamed frame is pushed into the buffer and the status refreshed.
    webcam_input.stream(
        capture_frame_into_buffer,
        inputs=[webcam_input],
        outputs=[buffer_status],
    )
    # Inference runs only when the button is clicked.
    predict_button.click(
        make_prediction_from_buffer,
        inputs=[],
        outputs=[prediction_output],
    )

if __name__ == "__main__":
    demo.launch()