import gradio as gr
import torch
from transformers import AutoImageProcessor, TimesformerForVideoClassification
import cv2
from PIL import Image
import numpy as np
import time
from collections import deque

# --- Configuration ---
# Your Hugging Face model repository ID
HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"

# These must match the values used during your training
NUM_FRAMES = 16
TARGET_IMAGE_HEIGHT = 224
TARGET_IMAGE_WIDTH = 224

# --- Load Model and Processor ---
print(f"Loading model and image processor from {HF_MODEL_REPO_ID}...")
try:
    processor = AutoImageProcessor.from_pretrained(HF_MODEL_REPO_ID)
    model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID)
except Exception as e:
    print(f"Error loading model from Hugging Face Hub: {e}")
    # Re-raise so the Space fails visibly instead of running without a model
    raise

model.eval()  # Set model to evaluation mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Model loaded successfully on {device}.")
print(f"Model's class labels: {model.config.id2label}")

# Initialize a global buffer of frames for the session.
# A deque gives efficient appends and automatically drops the oldest frame.
frame_buffer = deque(maxlen=NUM_FRAMES)
last_inference_time = time.time()
inference_interval = 1.0  # Seconds between predictions (run inference at most once per second)
current_prediction_text = "Buffering frames..."
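# NOTE: these module-level globals are shared by every connected client. That
# is fine for a single-viewer demo, but concurrent users' frames would
# interleave in one buffer; per-session state (gr.State) would avoid that.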

def predict_video_frame(image_np_array):
    global frame_buffer, last_inference_time, current_prediction_text

    # Gradio streams webcam frames as RGB numpy arrays
    pil_image = Image.fromarray(image_np_array)
    frame_buffer.append(pil_image)

    current_time = time.time()
    # Only run inference once the buffer is full and the interval has elapsed
    if len(frame_buffer) == NUM_FRAMES and (current_time - last_inference_time) >= inference_interval:
        last_inference_time = current_time

        # Preprocess the buffered frames. The processor accepts a list of PIL
        # Images (or numpy arrays) and returns pixel_values shaped
        # (1, NUM_FRAMES, 3, height, width), as TimeSformer expects.
        processed_input = processor(images=list(frame_buffer), return_tensors="pt")
        pixel_values = processed_input.pixel_values.to(device)

        with torch.no_grad():
            outputs = model(pixel_values)
            logits = outputs.logits

        predicted_class_id = logits.argmax(-1).item()
        predicted_label = model.config.id2label[predicted_class_id]
        confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()
        current_prediction_text = f"Predicted: {predicted_label} ({confidence:.2f})"
        print(current_prediction_text)  # Also appears in the Space logs

    # Return the latest prediction (or the buffering message) for the UI
    return current_prediction_text
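
# --- Optional: video file inference ---
# The commented-out `examples` hint further down suggests also supporting
# uploaded video files. The following is a minimal sketch, not part of the
# original demo: a hypothetical helper that samples NUM_FRAMES evenly spaced
# frames from a video file with OpenCV and runs a single prediction on them.
def predict_video_file(video_path):
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total_frames < NUM_FRAMES:
        cap.release()
        return f"Video too short: need at least {NUM_FRAMES} frames."

    # Evenly spaced frame indices across the whole clip
    indices = np.linspace(0, total_frames - 1, NUM_FRAMES).astype(int)
    frames = []
    for idx in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
        ok, frame_bgr = cap.read()
        if not ok:
            break
        # OpenCV decodes to BGR; the processor expects RGB
        frames.append(Image.fromarray(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)))
    cap.release()

    if len(frames) < NUM_FRAMES:
        return "Could not read enough frames from the video."

    pixel_values = processor(images=frames, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad():
        logits = model(pixel_values).logits
    predicted_class_id = logits.argmax(-1).item()
    confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()
    return f"Predicted: {model.config.id2label[predicted_class_id]} ({confidence:.2f})"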

# --- Gradio Interface ---
# Streaming webcam input. Note: gr.Image in Gradio 4.x has no `shape`
# argument; the image processor resizes frames itself, so none is needed.
webcam_input = gr.Image(
    sources=["webcam"],  # Allow webcam input
    streaming=True,      # Continuously stream frames to the prediction fn
    label="Live Webcam Feed",
)

# Output text box for predictions
prediction_output = gr.Textbox(label="Real-time Prediction")

# Define the Gradio interface. A simple gr.Interface with live=True is enough
# here; gr.Blocks would only be needed for finer layout control.
demo = gr.Interface(
    fn=predict_video_frame,
    inputs=webcam_input,
    outputs=prediction_output,
    live=True,               # Re-run the function as new frames stream in
    allow_flagging="never",  # Disable flagging on the public demo
    title="TimesFormer Crime Detection Live Demo",
    description=(
        f"This demo uses a fine-tuned TimeSformer model ({HF_MODEL_REPO_ID}) "
        f"to predict crime actions from a live webcam feed. The model processes "
        f"{NUM_FRAMES} frames at a time and makes a prediction every "
        f"{inference_interval} seconds. Please allow webcam access."
    ),
    # To also support video files, add an upload input wired to the
    # predict_video_file sketch above, e.g.:
    # examples=["path/to/your/test_video.mp4"]
)

if __name__ == "__main__":
    demo.launch()