import gradio as gr
import torch
from transformers import AutoImageProcessor, TimesformerForVideoClassification
import cv2  # used only by the optional video-file helper sketched below
from PIL import Image
import numpy as np
import time
from collections import deque

# --- Configuration ---
# Your Hugging Face model repository ID
HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"

# These must match the values used during your training
NUM_FRAMES = 16
TARGET_IMAGE_HEIGHT = 224
TARGET_IMAGE_WIDTH = 224

# --- Load Model and Processor ---
print(f"Loading model and image processor from {HF_MODEL_REPO_ID}...")
try:
    processor = AutoImageProcessor.from_pretrained(HF_MODEL_REPO_ID)
    model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID)
except Exception as e:
    print(f"Error loading model from Hugging Face Hub: {e}")
    # Re-raise so the Space fails visibly instead of exiting silently
    raise

model.eval() # Set model to evaluation mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Model loaded successfully on {device}.")
print(f"Model's class labels: {model.config.id2label}")

# Initialize a global buffer for frames for the session
# Use a deque for efficient appending/popping from both ends
frame_buffer = deque(maxlen=NUM_FRAMES)
last_inference_time = time.time()
inference_interval = 1.0 # Seconds between predictions (i.e. 1 / desired inference FPS)
current_prediction_text = "Buffering frames..."

def predict_video_frame(image_np_array):
    global frame_buffer, last_inference_time, current_prediction_text

    # Gradio streams frames as RGB numpy arrays
    pil_image = Image.fromarray(image_np_array)
    # Resize to the training resolution; the processor would also resize,
    # but doing it here keeps the buffered frames small
    pil_image = pil_image.resize((TARGET_IMAGE_WIDTH, TARGET_IMAGE_HEIGHT))
    frame_buffer.append(pil_image)

    current_time = time.time()

    # Only perform inference if we have enough frames and it's time for a prediction
    if len(frame_buffer) == NUM_FRAMES and (current_time - last_inference_time) >= inference_interval:
        last_inference_time = current_time

        # Preprocess the clip; the processor accepts a list of PIL Images or numpy arrays
        processed_input = processor(images=list(frame_buffer), return_tensors="pt")
        pixel_values = processed_input.pixel_values.to(device)

        with torch.no_grad():
            outputs = model(pixel_values)
            logits = outputs.logits

        predicted_class_id = logits.argmax(-1).item()
        predicted_label = model.config.id2label[predicted_class_id]
        confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()

        current_prediction_text = f"Predicted: {predicted_label} ({confidence:.2f})"
        print(current_prediction_text) # Print to Space logs
    
    # Return the current prediction text for display in the UI
    return current_prediction_text
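
# Optional local smoke test (a sketch, left commented out so it never runs in
# the Space): feed NUM_FRAMES of random RGB noise through predict_video_frame
# to verify the model path without a webcam. Resetting last_inference_time
# forces a prediction as soon as the buffer is full.
#
#   last_inference_time = 0.0
#   for _ in range(NUM_FRAMES):
#       noise = np.random.randint(0, 256, (TARGET_IMAGE_HEIGHT, TARGET_IMAGE_WIDTH, 3), dtype=np.uint8)
#       result = predict_video_frame(noise)
#   print(result)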

# --- Gradio Interface ---
# Create a streaming webcam input. Note that gr.Image in Gradio 4.x no longer
# accepts a `shape` argument; frames are resized inside predict_video_frame.
webcam_input = gr.Image(
    sources=["webcam"], # Webcam as the only input source
    streaming=True,      # Enables continuous streaming of frames
    label="Live Webcam Feed"
)

# Output text box for predictions
prediction_output = gr.Textbox(label="Real-time Prediction")


# Define the Gradio interface. gr.Blocks would give more layout control, but a
# basic gr.Interface with live=True is enough for a single streaming input and
# a text output.

demo = gr.Interface(
    fn=predict_video_frame,
    inputs=webcam_input,
    outputs=prediction_output,
    live=True, # Enable live updates
    allow_flagging="never", # Disable flagging on public demo
    title="TimesFormer Crime Detection Live Demo",
    description=f"This demo uses a finetuned TimesFormer model ({HF_MODEL_REPO_ID}) to predict crime actions from a live webcam feed. The model processes {NUM_FRAMES} frames at a time and makes a prediction every {inference_interval} seconds. Please allow webcam access.",
    # To also support uploaded video files, see the predict_video_file sketch
    # below; it could back a second interface with a gr.Video input.
)
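
# --- Optional: video file support (a sketch, not wired into the UI) ---
# One way to handle uploaded videos using the already-imported cv2 and numpy:
# sample NUM_FRAMES frames uniformly across the file and run one prediction.
# To expose it, add a second interface with a gr.Video input that calls this.
def predict_video_file(video_path):
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total_frames < NUM_FRAMES:
        cap.release()
        return f"Video too short: needs at least {NUM_FRAMES} frames."
    # Indices of NUM_FRAMES frames spread evenly across the clip
    indices = np.linspace(0, total_frames - 1, NUM_FRAMES).astype(int)
    frames = []
    for idx in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
        ok, frame_bgr = cap.read()
        if not ok:
            cap.release()
            return f"Failed to read frame {idx}."
        # OpenCV decodes to BGR; the model expects RGB
        frames.append(Image.fromarray(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)))
    cap.release()

    inputs = processor(images=frames, return_tensors="pt")
    with torch.no_grad():
        logits = model(inputs.pixel_values.to(device)).logits
    class_id = logits.argmax(-1).item()
    confidence = torch.nn.functional.softmax(logits, dim=-1)[0][class_id].item()
    return f"Predicted: {model.config.id2label[class_id]} ({confidence:.2f})"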

if __name__ == "__main__":
    demo.launch()