Spaces:
Running
Running
File size: 5,293 Bytes
8f7ac8f 9d0ee1c 8f7ac8f 9d0ee1c 8f7ac8f 9d0ee1c 8f7ac8f 9d0ee1c 8f7ac8f 9d0ee1c 8f7ac8f 9d0ee1c 8f7ac8f 9d0ee1c 8f7ac8f 9d0ee1c 8f7ac8f 9d0ee1c 8f7ac8f 9d0ee1c 8f7ac8f 9d0ee1c 8f7ac8f 9d0ee1c 8f7ac8f 9d0ee1c 8f7ac8f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import gradio as gr
import torch
from transformers import AutoImageProcessor, TimesformerForVideoClassification
import cv2
from PIL import Image
import numpy as np
import time
from collections import deque
# --- Configuration ---
# Hugging Face Hub repository that hosts the fine-tuned TimeSformer checkpoint.
HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"

# Clip geometry. These must agree with the settings used when training.
NUM_FRAMES = 16  # number of frames in one clip handed to the model per prediction
# Expected frame size; currently not referenced elsewhere in this script.
TARGET_IMAGE_HEIGHT = TARGET_IMAGE_WIDTH = 224
# --- Load Model and Processor ---
# Download (or reuse the cached) image processor and classifier once at import
# time; every Gradio handler below shares these module-level objects.
print(f"Loading model and image processor from {HF_MODEL_REPO_ID}...")
try:
    processor = AutoImageProcessor.from_pretrained(HF_MODEL_REPO_ID)
    model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID)
except Exception as e:
    print(f"Error loading model from Hugging Face Hub: {e}")
    # Abort with a non-zero status so the Space/process is reported as failed.
    # (A bare `exit()` exits with status 0 and is a `site`-module helper that
    # is not guaranteed to exist, e.g. under `python -S`.)
    raise SystemExit(1) from e

model.eval()  # Set model to evaluation (inference) mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Model loaded successfully on {device}.")
print(f"Model's class labels: {model.config.id2label}")
# Module-level rolling window of the most recent webcam frames (PIL images).
# Because maxlen is NUM_FRAMES, appending to a full deque silently evicts the
# oldest frame, so the buffer never holds more than NUM_FRAMES entries.
# Living at module scope lets it persist across separate Gradio handler calls.
captured_frames_buffer = deque(maxlen=NUM_FRAMES)
# Length, in seconds (5 minutes), of an optional artificial delay kept for
# manual testing; not read by any active code path in this script.
wait_duration_seconds = 5 * 60
# --- Function to continuously capture frames (without immediate processing) ---
def capture_frame_into_buffer(image_np_array):
    """Append one streamed webcam frame to the rolling frame buffer.

    Args:
        image_np_array: The frame delivered by the streaming ``gr.Image``
            component as an RGB numpy array, or ``None`` when the component
            currently has no image.

    Returns:
        A status string reporting how many frames are buffered out of the
        NUM_FRAMES needed for a prediction.
    """
    # No `global` statement is needed: the deque is mutated in place, never
    # rebound. Skip empty stream values instead of letting
    # Image.fromarray(None) raise inside the streaming handler.
    if image_np_array is not None:
        captured_frames_buffer.append(Image.fromarray(image_np_array))
    return f"Frames buffered: {len(captured_frames_buffer)}/{NUM_FRAMES}"
# --- Function to trigger prediction with the buffered frames ---
def make_prediction_from_buffer():
    """Classify the most recent NUM_FRAMES buffered webcam frames.

    Returns:
        ``"Predicted: <label> (<confidence>)"`` with the softmax probability of
        the top class formatted to two decimals, or a notice string when fewer
        than NUM_FRAMES frames have been buffered so far.

    The buffer is intentionally left intact after predicting, so clicking again
    before new frames arrive re-classifies the same clip.
    """
    if len(captured_frames_buffer) < NUM_FRAMES:
        return "Not enough frames buffered yet. Please capture more frames."

    # Snapshot the deque into a list: the streaming handler keeps appending to
    # it, and the processor is given a plain list of frames.
    frames_for_prediction = list(captured_frames_buffer)

    # --- Perform Inference ---
    print(f"Triggered inference on {len(frames_for_prediction)} frames...")
    processed_input = processor(images=frames_for_prediction, return_tensors="pt")
    pixel_values = processed_input.pixel_values.to(device)
    with torch.no_grad():
        logits = model(pixel_values).logits

    predicted_class_id = logits.argmax(-1).item()
    predicted_label = model.config.id2label[predicted_class_id]
    # Confidence = softmax probability of the predicted class (first, and
    # presumably only, clip in the batch).
    confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()

    prediction_text = f"Predicted: {predicted_label} ({confidence:.2f})"
    print(prediction_text)  # Print to Space logs
    return prediction_text
# --- Gradio Interface ---
# Layout: left column = live webcam + buffer status + predict button,
# right column = prediction result. Event listeners must be attached inside
# the Blocks context.
with gr.Blocks() as demo:
    gr.Markdown(
        f"""
# TimesFormer Crime Detection Live Demo (Manual Trigger)
This demo uses a finetuned TimesFormer model ({HF_MODEL_REPO_ID}) to predict crime actions from a live webcam feed.
It continuously buffers frames, but **only makes a prediction when you click the 'Predict' button**.
The model requires **{NUM_FRAMES} frames** for a prediction.
Please allow webcam access.
"""
    )
    with gr.Row():
        with gr.Column():
            webcam_input = gr.Image(
                sources=["webcam"],
                streaming=True,
                label="Live Webcam Feed",
            )
            # Shows the buffering status; refreshed on every streamed frame.
            buffer_status = gr.Textbox(
                label="Frame Buffer Status",
                value=f"Frames buffered: 0/{NUM_FRAMES}",
            )
            # Button to trigger prediction
            predict_button = gr.Button("Predict Latest Frames")
        with gr.Column():
            prediction_output = gr.Textbox(
                label="Prediction Result",
                value="Click 'Predict Latest Frames' to start.",
            )

    # Every streamed frame is pushed into the rolling buffer and the status
    # textbox is updated with the new count.
    webcam_input.stream(capture_frame_into_buffer, inputs=[webcam_input], outputs=[buffer_status])
    # Inference runs only on explicit button clicks, over the buffered frames.
    predict_button.click(make_prediction_from_buffer, inputs=[], outputs=[prediction_output])

if __name__ == "__main__":
    demo.launch()