import gradio as gr
import torch
from transformers import AutoImageProcessor, TimesformerForVideoClassification
import cv2
from PIL import Image
import numpy as np
import time
from collections import deque

# --- Configuration ---
# Your Hugging Face model repository ID
HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"

# These must match the values used during your training
NUM_FRAMES = 16  # Still expecting 16 frames per prediction
TARGET_IMAGE_HEIGHT = 224
TARGET_IMAGE_WIDTH = 224

# --- Load Model and Processor ---
print(f"Loading model and image processor from {HF_MODEL_REPO_ID}...")
try:
    processor = AutoImageProcessor.from_pretrained(HF_MODEL_REPO_ID)
    model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID)
except Exception as e:
    print(f"Error loading model from Hugging Face Hub: {e}")
    # Handle the error - exit (or raise) so the Space fails gracefully
    exit()

model.eval()  # Set model to evaluation mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

print(f"Model loaded successfully on {device}.")
print(f"Model's class labels: {model.config.id2label}")

# Initialize a global buffer for the frames the webcam continuously captures.
# This buffer holds only the *latest* NUM_FRAMES.
# We use a global variable to persist state across Gradio calls.
captured_frames_buffer = deque(maxlen=NUM_FRAMES)

# Duration of the optional artificial wait after a prediction (if still needed for testing)
wait_duration_seconds = 300  # 5 minutes


# --- Function to continuously capture frames (without immediate processing) ---
def capture_frame_into_buffer(image_np_array):
    global captured_frames_buffer

    # Convert Gradio's numpy array (RGB) to a PIL Image
    pil_image = Image.fromarray(image_np_array)
    captured_frames_buffer.append(pil_image)

    # Return a message showing how many frames are buffered
    return f"Frames buffered: {len(captured_frames_buffer)}/{NUM_FRAMES}"


# --- Function to trigger prediction with the buffered frames ---
def make_prediction_from_buffer():
    global captured_frames_buffer

    if len(captured_frames_buffer) < NUM_FRAMES:
        return "Not enough frames buffered yet. Please capture more frames."

    # Take a snapshot of the current frames in the buffer for prediction.
    # Convert the deque to a list for the processor.
    frames_for_prediction = list(captured_frames_buffer)

    # --- Perform Inference ---
    print(f"Triggered inference on {len(frames_for_prediction)} frames...")
    processed_input = processor(images=frames_for_prediction, return_tensors="pt")
    pixel_values = processed_input.pixel_values.to(device)

    with torch.no_grad():
        outputs = model(pixel_values)
        logits = outputs.logits

    predicted_class_id = logits.argmax(-1).item()
    predicted_label = model.config.id2label[predicted_class_id]
    confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()

    prediction_text = f"Predicted: {predicted_label} ({confidence:.2f})"
    print(prediction_text)  # Print to the Space logs

    # Clear the buffer after prediction if you want to capture a *new* set of
    # frames before the next click.
    # captured_frames_buffer.clear()
    # If you *don't* clear, the next click will re-predict on the same last 16 frames.

    # --- Introduce the artificial 5-minute wait (if still desired) ---
    # This pauses the *return* from this function, effectively blocking the UI update.
    # If you remove this, the prediction will show immediately.
    # print(f"Initiating {wait_duration_seconds} second wait...")
    # time.sleep(wait_duration_seconds)
    # print("Wait finished.")

    return prediction_text


# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown(
        f"""
        # TimesFormer Crime Detection Live Demo (Manual Trigger)
        This demo uses a fine-tuned TimesFormer model ({HF_MODEL_REPO_ID}) to predict crime actions from a live webcam feed.
        It continuously buffers frames, but **only makes a prediction when you click the 'Predict' button**.
        The model requires **{NUM_FRAMES} frames** for a prediction. Please allow webcam access.
        """
    )

    with gr.Row():
        with gr.Column():
            webcam_input = gr.Image(
                sources=["webcam"],
                streaming=True,
                label="Live Webcam Feed"
            )
            # This textbox shows the buffering status dynamically
            buffer_status = gr.Textbox(
                label="Frame Buffer Status",
                value=f"Frames buffered: 0/{NUM_FRAMES}"
            )
            # Button to trigger prediction
            predict_button = gr.Button("Predict Latest Frames")
        with gr.Column():
            prediction_output = gr.Textbox(
                label="Prediction Result",
                value="Click 'Predict Latest Frames' to start."
            )

    # Define actions
    # This continuously updates buffer_status as frames come in
    webcam_input.stream(
        capture_frame_into_buffer,
        inputs=[webcam_input],
        outputs=[buffer_status]
    )

    # This triggers the prediction when the button is clicked
    predict_button.click(
        make_prediction_from_buffer,
        inputs=[],
        outputs=[prediction_output]
    )


if __name__ == "__main__":
    demo.launch()