import gradio as gr
import torch
from transformers import AutoImageProcessor, TimesformerForVideoClassification
from PIL import Image
import numpy as np
import time
from collections import deque

# --- Configuration ---
# Your Hugging Face model repository ID
HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"

# These must match the values used during your training
NUM_FRAMES = 8
TARGET_IMAGE_HEIGHT = 224
TARGET_IMAGE_WIDTH = 224

# --- Load Model and Processor ---
print(f"Loading model and image processor from {HF_MODEL_REPO_ID}...")
try:
    processor = AutoImageProcessor.from_pretrained(HF_MODEL_REPO_ID)
    model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID)
except Exception as e:
    print(f"Error loading model from Hugging Face Hub: {e}")
    # Exit so the Space fails visibly instead of running without a model
    exit()

model.eval()  # Set model to evaluation mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Model loaded successfully on {device}.")
print(f"Model's class labels: {model.config.id2label}")

# Global frame buffer for the session.
# A deque with maxlen gives efficient appends and automatically drops the oldest frame.
# NOTE: module-level state like this is shared across all concurrent users of a Space.
frame_buffer = deque(maxlen=NUM_FRAMES)
last_inference_time = time.time()
inference_interval = 1.0  # Predict every 1 second (1.0 / INFERENCE_FPS)
current_prediction_text = "Buffering frames..."  # Initial text shown in the UI


def predict_video_frame(image_np_array: np.ndarray):
    global frame_buffer, last_inference_time, current_prediction_text

    # Gradio streams frames as RGB numpy arrays; the image processor handles
    # resizing to TARGET_IMAGE_HEIGHT x TARGET_IMAGE_WIDTH
    pil_image = Image.fromarray(image_np_array)
    frame_buffer.append(pil_image)

    current_time = time.time()

    # Only run inference once the buffer is full and the inference interval has elapsed
    if len(frame_buffer) == NUM_FRAMES and (current_time - last_inference_time) >= inference_interval:
        last_inference_time = current_time

        # Preprocess the frames. The processor expects a list of PIL Images or
        # numpy arrays and handles resizing and normalization based on its config.
        processed_input = processor(images=list(frame_buffer), return_tensors="pt")
        pixel_values = processed_input.pixel_values.to(device)

        with torch.no_grad():
            outputs = model(pixel_values)
            logits = outputs.logits

        predicted_class_id = logits.argmax(-1).item()
        predicted_label = model.config.id2label[predicted_class_id]
        confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()

        current_prediction_text = f"Predicted: {predicted_label} ({confidence:.2f})"
        print(current_prediction_text)  # Also goes to the Space logs

    # Return the latest prediction text for display in the UI;
    # Gradio's streaming updates the textbox asynchronously
    return current_prediction_text


# --- Gradio Interface ---
# Streaming webcam input
webcam_input = gr.Image(
    sources=["webcam"],   # Allow webcam input
    streaming=True,       # Continuously stream frames to the callback
    # REMOVED: shape=(TARGET_IMAGE_WIDTH, TARGET_IMAGE_HEIGHT) -- it caused a TypeError;
    # the image processor resizes frames instead
    label="Live Webcam Feed",
)

# Output textbox for predictions
prediction_output = gr.Textbox(label="Real-time Prediction", value="Buffering frames...")

# Define the Gradio Interface
demo = gr.Interface(
    fn=predict_video_frame,
    inputs=webcam_input,
    outputs=prediction_output,
    live=True,                # Enable live updates
    allow_flagging="never",   # Disable flagging on the public demo
    title="TimesFormer Crime Detection Live Demo",
    description=(
        f"This demo uses a finetuned TimesFormer model ({HF_MODEL_REPO_ID}) to predict "
        f"crime actions from a live webcam feed. The model processes {NUM_FRAMES} frames "
        f"at a time and makes a prediction every {inference_interval} seconds. "
        "Please allow webcam access."
    ),
)

if __name__ == "__main__":
    demo.launch()
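

# --- Optional offline smoke test (an illustrative sketch; not part of the original Space code) ---
# Pushes NUM_FRAMES synthetic RGB frames through predict_video_frame to verify
# the buffering/preprocessing/inference path without a webcam. It is defined
# below the launch guard so the Space never executes it; invoke it manually,
# e.g. `python -c "import app; app.smoke_test()"` (assuming this file is app.py).
def smoke_test():
    global last_inference_time
    last_inference_time = 0.0  # force the inference branch on the final frame
    rng = np.random.default_rng(0)
    for _ in range(NUM_FRAMES):
        # Random uint8 frames (H x W x RGB) stand in for webcam captures
        frame = rng.integers(0, 256, size=(480, 640, 3), dtype=np.uint8)
        result = predict_video_frame(frame)
    print(f"Smoke test prediction: {result}")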