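# Copied from the Hugging Face file viewer — author: owinymarvin,
# commit 8f7ac8f ("first commit"), file size 4.17 kB.
# (Viewer UI labels "raw" / "history blame" omitted.)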
import gradio as gr
import torch
from transformers import AutoImageProcessor, TimesformerForVideoClassification
import cv2
from PIL import Image
import numpy as np
import time
from collections import deque
# --- Configuration ---
# Hugging Face Hub repository holding the fine-tuned checkpoint and the
# image-processor settings that go with it.
HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"

# Clip length and frame size; keep these identical to the training setup.
NUM_FRAMES = 16  # number of buffered frames classified together as one clip
TARGET_IMAGE_HEIGHT = TARGET_IMAGE_WIDTH = 224
# --- Load Model and Processor ---
print(f"Loading model and image processor from {HF_MODEL_REPO_ID}...")
try:
    processor = AutoImageProcessor.from_pretrained(HF_MODEL_REPO_ID)
    model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID)
except Exception as e:
    print(f"Error loading model from Hugging Face Hub: {e}")
    # Abort start-up with a non-zero exit status so the hosting platform
    # reports a failed launch. A bare `exit()` would exit with status 0
    # (looks like success) and is a `site`-module helper meant for the
    # interactive shell, not for programs.
    raise SystemExit(1) from e

model.eval()  # Set model to evaluation mode (inference only)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Model loaded successfully on {device}.")
print(f"Model's class labels: {model.config.id2label}")
# --- Streaming state ---
# These are module-level globals, so they are shared by every call to
# predict_video_frame in this process. With Gradio that presumably means all
# connected viewers share one buffer, not one per browser session.
# NOTE(review): isolating viewers would require per-session gr.State.
# A deque with maxlen keeps only the most recent NUM_FRAMES frames, so the
# model always sees a sliding window over the latest video.
frame_buffer = deque(maxlen=NUM_FRAMES)
# Monotonic clock: unlike time.time(), it cannot jump backwards or forwards
# when the system clock is adjusted, so the inference schedule stays steady.
last_inference_time = time.monotonic()
inference_interval = 1.0  # minimum number of seconds between two predictions
current_prediction_text = "Buffering frames..."


def predict_video_frame(image_np_array):
    """Buffer one streamed webcam frame and periodically classify the clip.

    Every call appends the frame to the sliding window ``frame_buffer``. Once
    the window holds ``NUM_FRAMES`` frames and at least ``inference_interval``
    seconds have passed since the previous prediction, the buffered clip is
    run through the model and the displayed text is refreshed. Otherwise the
    previous text is returned unchanged.

    Args:
        image_np_array: One webcam frame as delivered by Gradio (a numpy
            array, RGB per the original author's note), or ``None`` when no
            frame is available. ``None`` frames are skipped.

    Returns:
        The latest prediction string, e.g. ``"Predicted: <label> (0.87)"``,
        or ``"Buffering frames..."`` before the first prediction.
    """
    global last_inference_time, current_prediction_text

    # A streaming webcam component can emit None (camera not started or
    # stopped); Image.fromarray(None) would raise, so keep the last result.
    if image_np_array is None:
        return current_prediction_text

    frame_buffer.append(Image.fromarray(image_np_array))

    now = time.monotonic()
    if (
        len(frame_buffer) < NUM_FRAMES
        or now - last_inference_time < inference_interval
    ):
        # Not enough frames yet, or too soon since the last prediction.
        return current_prediction_text

    last_inference_time = now
    # The processor receives the buffered frames as a single clip.
    # NOTE(review): pixel_values is presumably (1, NUM_FRAMES, C, H, W) —
    # confirm against the processor class used by the checkpoint.
    inputs = processor(images=list(frame_buffer), return_tensors="pt")
    pixel_values = inputs.pixel_values.to(device)

    with torch.no_grad():
        logits = model(pixel_values).logits

    predicted_class_id = logits.argmax(-1).item()
    predicted_label = model.config.id2label[predicted_class_id]
    confidence = torch.softmax(logits, dim=-1)[0, predicted_class_id].item()

    current_prediction_text = f"Predicted: {predicted_label} ({confidence:.2f})"
    print(current_prediction_text)  # Print to Space logs
    return current_prediction_text
# --- Gradio Interface ---
# Streaming webcam input: with streaming=True, Gradio keeps sending frames to
# the prediction function (as numpy arrays, gr.Image's default type).
# Frames are passed on at camera resolution. Resizing/cropping to the model's
# input size is left to `processor`, which is configured from the checkpoint.
# (The former `shape=` argument was a Gradio 3.x gr.Image option. The Gradio 4
# gr.Image that `sources=[...]` belongs to does not accept it.)
webcam_input = gr.Image(
    sources=["webcam"],  # capture from the browser webcam
    streaming=True,  # deliver frames continuously instead of one snapshot
    label="Live Webcam Feed",
)

# Output text box for predictions
prediction_output = gr.Textbox(label="Real-time Prediction")
# --- App definition ---
# A plain gr.Interface is sufficient: one streaming input, one text output.
# live=True makes the Interface call predict_video_frame on every new input
# instead of waiting for a Submit click, which streaming input requires.
demo = gr.Interface(
    fn=predict_video_frame,
    inputs=webcam_input,
    outputs=prediction_output,
    live=True,
    allow_flagging="never",  # Disable flagging on public demo
    title="TimesFormer Crime Detection Live Demo",
    description=(
        f"This demo uses a finetuned TimesFormer model ({HF_MODEL_REPO_ID}) "
        "to predict crime actions from a live webcam feed. "
        f"The model processes {NUM_FRAMES} frames at a time and makes a "
        f"prediction every {inference_interval} seconds. "
        "Please allow webcam access."
    ),
    # TODO: to also accept uploaded video files, add a gr.Video input and an
    # `examples=[...]` list pointing at sample clips.
)

if __name__ == "__main__":
    demo.launch()