import gradio as gr
import torch
from transformers import AutoImageProcessor, TimesformerForVideoClassification
from PIL import Image
import time  # only needed if the optional artificial wait below is re-enabled
from collections import deque

# --- Configuration ---
# Your Hugging Face model repository ID
HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"

# These must match the values used during your training
NUM_FRAMES = 16 # Number of frames per clip, matching the training setup
TARGET_IMAGE_HEIGHT = 224
TARGET_IMAGE_WIDTH = 224
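
# Note: TimesformerForVideoClassification expects pixel_values of shape
# (batch_size, num_frames, num_channels, height, width), so one clip here
# is (1, 16, 3, 224, 224).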

# --- Load Model and Processor ---
print(f"Loading model and image processor from {HF_MODEL_REPO_ID}...")
try:
    processor = AutoImageProcessor.from_pretrained(HF_MODEL_REPO_ID)
    model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID)
except Exception as e:
    print(f"Error loading model from Hugging Face Hub: {e}")
    # Re-raise so the Space fails loudly and the error surfaces in the logs
    raise

model.eval() # Set model to evaluation mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Model loaded successfully on {device}.")
print(f"Model's class labels: {model.config.id2label}")

# Global frame buffer continuously fed by the webcam stream; it always holds
# the *latest* NUM_FRAMES. Note that module-level state in a Gradio app is
# shared across all concurrent users of the Space.
captured_frames_buffer = deque(maxlen=NUM_FRAMES)
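# Example of the sliding-window behaviour deque(maxlen=N) provides:
#   d = deque(maxlen=3); d.extend([1, 2, 3]); d.append(4)
#   # d is now deque([2, 3, 4]) -- the oldest item was evicted automatically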

# Optional artificial delay (in seconds) applied after a prediction; used only
# by the commented-out time.sleep() in make_prediction_from_buffer().
wait_duration_seconds = 300 # 5 minutes

# --- Function to continuously capture frames (without immediate processing) ---
def capture_frame_into_buffer(image_np_array):
    global captured_frames_buffer

    # Convert Gradio's numpy array (RGB) to PIL Image
    pil_image = Image.fromarray(image_np_array)
    captured_frames_buffer.append(pil_image)
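    # deque appends are thread-safe in CPython, so overlapping stream
    # callbacks will not corrupt the buffer.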

    # Return a message showing how many frames are buffered
    return f"Frames buffered: {len(captured_frames_buffer)}/{NUM_FRAMES}"


# --- Function to trigger prediction with the buffered frames ---
def make_prediction_from_buffer():
    global captured_frames_buffer

    if len(captured_frames_buffer) < NUM_FRAMES:
        return "Not enough frames buffered yet. Please capture more frames."

    # Take a snapshot of the current frames in the buffer for prediction
    # Convert deque to a list for the processor
    frames_for_prediction = list(captured_frames_buffer)

    # --- Perform Inference ---
    print(f"Triggered inference on {len(frames_for_prediction)} frames...")
    processed_input = processor(images=frames_for_prediction, return_tensors="pt")
    pixel_values = processed_input.pixel_values.to(device)
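    # The image processor resizes/normalizes every frame and stacks the list
    # into a single video tensor of shape (1, NUM_FRAMES, 3, 224, 224).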

    with torch.no_grad():
        outputs = model(pixel_values)
        logits = outputs.logits

    predicted_class_id = logits.argmax(-1).item()
    predicted_label = model.config.id2label[predicted_class_id]
    confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()

    prediction_text = f"Predicted: {predicted_label} ({confidence:.2f})"
    print(prediction_text) # Print to Space logs
    
    # Optionally clear the buffer so the next click must wait for a fresh set of frames:
    # captured_frames_buffer.clear()
    # If you don't clear, the next click predicts on whatever the latest
    # NUM_FRAMES in the buffer are (the webcam stream keeps refreshing it).
    
    # --- Optional artificial wait (see wait_duration_seconds above) ---
    # Sleeping here blocks this function's return, which delays the UI update;
    # leave it commented out to show the prediction immediately.
    # print(f"Initiating {wait_duration_seconds} second wait...")
    # time.sleep(wait_duration_seconds)
    # print("Wait finished.")

    return prediction_text
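
# A hedged sketch, not wired into the UI: a variant of the prediction step
# that returns the full class-probability mapping instead of a single string,
# suitable for a gr.Label output. The helper name `predict_probabilities`
# is illustrative, not part of the original app.
def predict_probabilities():
    if len(captured_frames_buffer) < NUM_FRAMES:
        return {}
    inputs = processor(images=list(captured_frames_buffer), return_tensors="pt")
    with torch.no_grad():
        logits = model(inputs.pixel_values.to(device)).logits
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    # Map every label to its probability, e.g. {"fighting": 0.91, "normal": 0.09}
    return {model.config.id2label[i]: p.item() for i, p in enumerate(probs)}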


# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown(
        f"""
        # TimesFormer Crime Detection Live Demo (Manual Trigger)
        This demo uses a fine-tuned TimesFormer model ({HF_MODEL_REPO_ID}) to predict crime actions from a live webcam feed.
        It continuously buffers frames, but **only makes a prediction when you click the 'Predict' button**.
        The model requires **{NUM_FRAMES} frames** for a prediction.
        Please allow webcam access.
        """
    )
    with gr.Row():
        with gr.Column():
            webcam_input = gr.Image(
                sources=["webcam"],
                streaming=True,
                label="Live Webcam Feed"
            )
            # This textbox will show the buffering status dynamically
            buffer_status = gr.Textbox(label="Frame Buffer Status", value=f"Frames buffered: 0/{NUM_FRAMES}")
            
            # Button to trigger prediction
            predict_button = gr.Button("Predict Latest Frames")
            
        with gr.Column():
            prediction_output = gr.Textbox(label="Prediction Result", value="Click 'Predict Latest Frames' to start.")
            
    # Define actions
    # This continuously updates the buffer_status as frames come in
    webcam_input.stream(capture_frame_into_buffer, inputs=[webcam_input], outputs=[buffer_status])
    
    # This triggers the prediction when the button is clicked
    predict_button.click(make_prediction_from_buffer, inputs=[], outputs=[prediction_output])


if __name__ == "__main__":
    demo.launch()