import gradio as gr
import torch
from transformers import AutoImageProcessor, TimesformerForVideoClassification
import cv2
from PIL import Image
import numpy as np
import time
from collections import deque
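# Note: these imports assume the Space's requirements.txt lists gradio, torch,
# transformers, opencv-python, Pillow, and numpy.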
# --- Configuration ---
# Your Hugging Face model repository ID
HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"
# These must match the values used during your training
NUM_FRAMES = 16
TARGET_IMAGE_HEIGHT = 224
TARGET_IMAGE_WIDTH = 224
# --- Load Model and Processor ---
print(f"Loading model and image processor from {HF_MODEL_REPO_ID}...")
try:
    processor = AutoImageProcessor.from_pretrained(HF_MODEL_REPO_ID)
    model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID)
except Exception as e:
    print(f"Error loading model from Hugging Face Hub: {e}")
    # Handle the error: exit so the Space fails gracefully
    exit()
model.eval() # Set model to evaluation mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Model loaded successfully on {device}.")
print(f"Model's class labels: {model.config.id2label}")
# Initialize a global buffer for frames for the session
# Use a deque for efficient appending/popping from both ends
frame_buffer = deque(maxlen=NUM_FRAMES)
last_inference_time = time.time()
inference_interval = 1.0  # Run a prediction at most once per second
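# Shown in the UI until the first full clip of NUM_FRAMES frames has been classified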
current_prediction_text = "Buffering frames..."
def predict_video_frame(image_np_array):
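    """Buffer one incoming webcam frame and return the latest prediction string.

    Gradio calls this for every streamed frame; inference only runs once the
    buffer holds NUM_FRAMES frames and inference_interval seconds have elapsed.
    """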
    global frame_buffer, last_inference_time, current_prediction_text

    # Gradio sends frames as numpy arrays (RGB)
    pil_image = Image.fromarray(image_np_array)
    frame_buffer.append(pil_image)

    current_time = time.time()

    # Only perform inference if we have enough frames and it's time for a prediction
    if len(frame_buffer) == NUM_FRAMES and (current_time - last_inference_time) >= inference_interval:
        last_inference_time = current_time

        # Preprocess the frames. The processor expects a list of PIL Images or numpy arrays
        processed_input = processor(images=list(frame_buffer), return_tensors="pt")
        pixel_values = processed_input.pixel_values.to(device)

        with torch.no_grad():
            outputs = model(pixel_values)
            logits = outputs.logits

        predicted_class_id = logits.argmax(-1).item()
        predicted_label = model.config.id2label[predicted_class_id]
        confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()

        current_prediction_text = f"Predicted: {predicted_label} ({confidence:.2f})"
        print(current_prediction_text)  # Print to Space logs

    # Return the current prediction text for display in the UI
    return current_prediction_text
# --- Gradio Interface ---
# Create a streaming input for webcam
webcam_input = gr.Image(
    sources=["webcam"],  # Allow webcam input
    streaming=True,      # Enable continuous streaming of frames
    label="Live Webcam Feed",
    # Note: gr.Image no longer accepts a `shape` argument (removed in Gradio 4.x);
    # the image processor handles resizing to TARGET_IMAGE_WIDTH x TARGET_IMAGE_HEIGHT.
)
# Output text box for predictions
prediction_output = gr.Textbox(label="Real-time Prediction")
# Define the Gradio Interface.
# gr.Blocks would give more control over the layout, but a basic gr.Interface
# with live=True is enough for a single streaming input and a text output.
demo = gr.Interface(
    fn=predict_video_frame,
    inputs=webcam_input,
    outputs=prediction_output,
    live=True,  # Enable live updates on every streamed frame
    allow_flagging="never",  # Disable flagging on the public demo
    title="TimesFormer Crime Detection Live Demo",
    description=(
        f"This demo uses a fine-tuned TimesFormer model ({HF_MODEL_REPO_ID}) to predict crime actions "
        f"from a live webcam feed. The model processes {NUM_FRAMES} frames at a time and makes a "
        f"prediction every {inference_interval} seconds. Please allow webcam access."
    ),
    # To also support uploaded video files, add a video input and examples here, e.g.
    # examples=["path/to/your/test_video.mp4"]
)
if __name__ == "__main__":
    demo.launch()