# File size: 4,075 Bytes
# 9d0ee1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import gradio as gr
import torch
from transformers import AutoImageProcessor, TimesformerForVideoClassification
import cv2
from PIL import Image
import numpy as np
import time
from collections import deque

# --- Configuration ---
# Hugging Face Hub repository holding the fine-tuned TimeSformer checkpoint
# and its image-processor config.
HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"

# Clip geometry. Keep these in sync with the values used at training time.
# NUM_FRAMES is the sliding-window length: a prediction only runs once this
# many frames have been collected.
NUM_FRAMES = 8
# NOTE(review): the height/width below are informational only in this
# script; resizing is delegated to the loaded image processor.
TARGET_IMAGE_HEIGHT, TARGET_IMAGE_WIDTH = 224, 224

# --- Load Model and Processor ---
# Both objects come from the same Hub repo so the preprocessing (resize,
# normalization) matches what the checkpoint was fine-tuned with.
print(f"Loading model and image processor from {HF_MODEL_REPO_ID}...")
try:
    processor = AutoImageProcessor.from_pretrained(HF_MODEL_REPO_ID)
    model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID)
except Exception as e:
    print(f"Error loading model from Hugging Face Hub: {e}")
    # The demo cannot run without the model, so stop the process. Exit with a
    # non-zero status so the hosting environment sees a failure rather than a
    # clean shutdown (a bare exit() would report status 0 and depends on the
    # `site` module having installed `exit` as a builtin).
    raise SystemExit(1) from e

model.eval()  # Set model to evaluation mode (inference only)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Model loaded successfully on {device}.")
print(f"Model's class labels: {model.config.id2label}")

# --- Streaming state (module-level) ---
# NOTE(review): these are plain module globals, so they appear to be shared by
# every connected browser session; concurrent viewers would interleave frames
# in one buffer. Consider per-session state (e.g. gr.State) if that matters.

# Sliding window over the most recent frames: with a bounded deque, each
# append beyond NUM_FRAMES silently evicts the oldest frame.
frame_buffer = deque([], maxlen=NUM_FRAMES)

# Wall-clock timestamp (time.time()) of the last inference, seeded at import.
last_inference_time = time.time()

# Minimum number of seconds between two consecutive predictions.
inference_interval = 1.0

# Text returned to the UI until the first prediction overwrites it.
current_prediction_text = "Buffering frames..."

def predict_video_frame(image_np_array):
    """Buffer one streamed webcam frame and, when due, classify the clip.

    Each call appends the incoming frame to the module-level sliding window
    ``frame_buffer``. When the window holds ``NUM_FRAMES`` frames and at least
    ``inference_interval`` seconds have passed since ``last_inference_time``,
    the buffered frames are preprocessed, run through ``model`` and
    ``current_prediction_text`` is refreshed.

    Args:
        image_np_array: The latest frame from the streaming Gradio image
            component as a numpy array (RGB), or ``None`` when no frame was
            delivered.

    Returns:
        The current prediction text (or the buffering placeholder), so the UI
        keeps showing the last result between predictions.
    """
    global frame_buffer, last_inference_time, current_prediction_text

    # Gradio may call a streaming handler with None (e.g. no active webcam
    # frame); Image.fromarray(None) would raise, so keep the last result.
    if image_np_array is None:
        return current_prediction_text

    # Frames are stored at their native size; the image processor handles
    # resizing and normalization at inference time.
    frame_buffer.append(Image.fromarray(image_np_array))

    current_time = time.time()
    buffer_full = len(frame_buffer) == NUM_FRAMES
    interval_elapsed = (current_time - last_inference_time) >= inference_interval
    if not (buffer_full and interval_elapsed):
        return current_prediction_text

    last_inference_time = current_time

    # The processor accepts a list of PIL images; presumably this yields a
    # single clip (batch of 1) — the [0] indexing below relies on that.
    processed_input = processor(images=list(frame_buffer), return_tensors="pt")
    pixel_values = processed_input.pixel_values.to(device)

    with torch.no_grad():
        logits = model(pixel_values).logits

    predicted_class_id = logits.argmax(-1).item()
    predicted_label = model.config.id2label[predicted_class_id]
    # Softmax once and read the winning class's probability as confidence.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
    confidence = probabilities[predicted_class_id].item()

    current_prediction_text = f"Predicted: {predicted_label} ({confidence:.2f})"
    print(current_prediction_text)  # Print to Space logs

    # Gradio's streaming will update the textbox with this value.
    return current_prediction_text

# --- Gradio Interface ---
# Webcam-only image input in streaming mode, so frames are pushed to the
# handler continuously rather than on an explicit submit.
webcam_input = gr.Image(
    label="Live Webcam Feed",
    sources=["webcam"],
    streaming=True,
)

# Read-only display of the latest prediction; starts with the buffering notice.
prediction_output = gr.Textbox(
    value="Buffering frames...",
    label="Real-time Prediction",
)

# live=True lets the interface react to incoming input without a submit
# button; flagging is disabled because this is a public demo.
demo = gr.Interface(
    fn=predict_video_frame,
    inputs=webcam_input,
    outputs=prediction_output,
    title="TimesFormer Crime Detection Live Demo",
    description=(
        f"This demo uses a finetuned TimesFormer model ({HF_MODEL_REPO_ID}) "
        "to predict crime actions from a live webcam feed. "
        f"The model processes {NUM_FRAMES} frames at a time and makes a "
        f"prediction every {inference_interval} seconds. "
        "Please allow webcam access."
    ),
    live=True,
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()