File size: 8,225 Bytes
8f7ac8f
 
 
 
 
 
 
 
4873e8b
 
8f7ac8f
 
4873e8b
8f7ac8f
 
4873e8b
 
 
8f7ac8f
4873e8b
8f7ac8f
 
 
 
4873e8b
8f7ac8f
 
4873e8b
8f7ac8f
 
4873e8b
8f7ac8f
4873e8b
 
 
4784ef2
 
 
 
 
 
 
 
c450b97
 
4784ef2
 
 
 
9d0ee1c
6f472e5
4873e8b
9d0ee1c
4784ef2
 
8f7ac8f
4784ef2
 
 
4873e8b
 
 
4784ef2
4873e8b
 
 
 
 
 
 
 
 
 
 
 
 
 
c450b97
 
4873e8b
 
 
 
 
 
 
 
 
 
 
 
 
c450b97
4873e8b
 
 
 
 
 
4784ef2
4873e8b
 
 
 
 
 
 
 
 
6f472e5
4873e8b
 
4784ef2
 
 
4873e8b
 
4784ef2
4873e8b
 
 
 
4784ef2
 
 
 
 
 
 
 
4873e8b
 
 
8f7ac8f
9d0ee1c
 
 
4873e8b
 
 
9d0ee1c
 
4873e8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c450b97
 
4873e8b
c450b97
 
 
 
 
4873e8b
c450b97
 
 
 
 
8f7ac8f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import gradio as gr
import torch
from transformers import AutoImageProcessor, TimesformerForVideoClassification
import cv2
from PIL import Image
import numpy as np
import time
from collections import deque
import base64
import io

# Hugging Face Hub repo that provides both the image processor and the model weights.
HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"
# Number of frames passed to the model for a single prediction.
MODEL_INPUT_NUM_FRAMES = 8
# NOTE(review): these two sizes are not referenced by the code in this module;
# resizing appears to be left to `processor`. Confirm before relying on them.
TARGET_IMAGE_HEIGHT = 224
TARGET_IMAGE_WIDTH = 224
# Length of each webcam recording window, in seconds.
RAW_RECORDING_DURATION_SECONDS = 10.0
# Intermediate number of evenly spaced frames kept from a recording before the
# final MODEL_INPUT_NUM_FRAMES are chosen from them.
FRAMES_TO_SAMPLE_PER_CLIP = 20
# Cool-down after a prediction before the next recording starts, in seconds.
DELAY_BETWEEN_PREDICTIONS_SECONDS = 120.0

print(f"Loading model and processor from {HF_MODEL_REPO_ID}...")
try:
    processor = AutoImageProcessor.from_pretrained(HF_MODEL_REPO_ID)
    model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID)
except Exception as e:
    print(f"Error loading model: {e}")
    # The app cannot run without the model. Exit with a non-zero status so the
    # host sees the failure; the site-provided `exit()` would report success
    # (status 0) and is not guaranteed to be available.
    raise SystemExit(1) from e

# Switch to evaluation mode for inference and use the GPU when one is available.
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Model loaded on {device}.")

# Mutable state shared by the Gradio callbacks (record -> predict -> delay cycle).
# PIL frames collected during the current recording window.
raw_frames_buffer = deque()
# Wall-clock start (time.time()) of the current recording window.
current_clip_start_time = time.time()
# Wall-clock time the delay phase is measured from.
last_prediction_completion_time = time.time()
# One of "recording", "predicting", "processing_delay".
app_state = "recording"

def sample_frames(frames_list, target_count):
    """Return ``target_count`` items spread evenly across ``frames_list``.

    Positions come from ``np.linspace`` over the valid index range, truncated to
    ``int``; for ``target_count >= 2`` the first and last items are always kept.
    An empty (or ``None``) input yields a new empty list, and a sequence holding
    at most ``target_count`` items is returned as-is (the same object, no copy).
    """
    if not frames_list:
        return []

    available = len(frames_list)
    if target_count >= available:
        # Nothing to thin out.
        return frames_list

    picks = np.linspace(0, available - 1, target_count, dtype=int)
    return [frames_list[idx] for idx in picks.tolist()]

# Text of the most recent prediction (or prediction error). It is re-sent while
# the app waits between predictions so the result stays visible instead of being
# replaced by an empty string on the next streamed frame.
_last_prediction_display = ""


def _classify_frames(frames):
    """Run the TimeSformer model on ``frames`` and return ``(label, confidence)``.

    ``frames`` is passed straight to ``processor``. ``confidence`` is the softmax
    probability of the arg-max class for the first item of the batch.
    """
    processed_input = processor(images=frames, return_tensors="pt")
    pixel_values = processed_input.pixel_values.to(device)

    with torch.no_grad():
        logits = model(pixel_values).logits

    predicted_class_id = logits.argmax(-1).item()
    predicted_label = model.config.id2label.get(predicted_class_id, "Unknown")
    confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()
    return predicted_label, confidence


def live_predict_stream(image_np_array):
    """Advance the record -> predict -> delay state machine by one webcam frame.

    Called for every streamed frame. Depending on the module-level ``app_state``
    it buffers the frame ("recording"), classifies the buffered clip
    ("predicting"), or waits out DELAY_BETWEEN_PREDICTIONS_SECONDS
    ("processing_delay").

    Args:
        image_np_array: The current webcam frame as an array accepted by
            ``Image.fromarray``, or ``None`` when no image is delivered.

    Returns:
        A ``(status_message, prediction_result)`` tuple for the two textboxes.
    """
    global raw_frames_buffer, current_clip_start_time, last_prediction_completion_time, app_state
    global _last_prediction_display

    current_time = time.time()

    status_message = ""
    prediction_result = ""

    if app_state == "recording":
        # Only frames that are kept get converted; a missing frame is skipped
        # rather than crashing Image.fromarray.
        if image_np_array is not None:
            raw_frames_buffer.append(Image.fromarray(image_np_array))
        elapsed_recording_time = current_time - current_clip_start_time
        status_message = f"Recording: {elapsed_recording_time:.1f}/{RAW_RECORDING_DURATION_SECONDS}s. Raw frames: {len(raw_frames_buffer)}"
        prediction_result = "Buffering..."
        if elapsed_recording_time >= RAW_RECORDING_DURATION_SECONDS:
            app_state = "predicting"
            status_message = "Preparing to predict..."
            prediction_result = "Processing..."
            print("DEBUG: Transitioning to 'predicting' state.")

    elif app_state == "predicting":
        if not raw_frames_buffer:
            # Frames are only buffered while recording, so waiting in this state
            # would never end; start a new recording window instead.
            status_message = "Waiting for frames..."
            prediction_result = "..."
            app_state = "recording"
            current_clip_start_time = current_time
            return status_message, prediction_result

        print("DEBUG: Starting prediction.")
        try:
            # Two-stage even sampling: thin the raw buffer to FRAMES_TO_SAMPLE_PER_CLIP,
            # then pick MODEL_INPUT_NUM_FRAMES from those.
            sampled_raw_frames = sample_frames(list(raw_frames_buffer), FRAMES_TO_SAMPLE_PER_CLIP)
            frames_for_model = sample_frames(sampled_raw_frames, MODEL_INPUT_NUM_FRAMES)

            if len(frames_for_model) < MODEL_INPUT_NUM_FRAMES:
                prediction_result = "Error: Not enough frames for model."
                status_message = "Error during frame sampling."
                print(f"ERROR: Insufficient frames for model input: {len(frames_for_model)}/{MODEL_INPUT_NUM_FRAMES}")
                app_state = "recording"  # Reset to recording state
                raw_frames_buffer.clear()
                current_clip_start_time = time.time()
                last_prediction_completion_time = time.time()
                return status_message, prediction_result

            predicted_label, confidence = _classify_frames(frames_for_model)
        except Exception as e:
            prediction_result = f"Error during prediction: {e}"
            status_message = "Prediction error."
            print(f"ERROR during prediction: {e}")
        else:
            prediction_result = f"Predicted: {predicted_label} (Confidence: {confidence:.2f})"
            status_message = "Prediction complete."
            print(f"DEBUG: Prediction Result: {prediction_result}")

        # Success or failure: drop the consumed clip and start the cool-down from
        # the moment inference finished. A failed prediction therefore also waits
        # out the full delay instead of reusing stale frames in the next clip.
        raw_frames_buffer.clear()
        last_prediction_completion_time = time.time()
        _last_prediction_display = prediction_result
        app_state = "processing_delay"
        print("DEBUG: Transitioning to 'processing_delay' state.")

    elif app_state == "processing_delay":
        elapsed_delay = current_time - last_prediction_completion_time
        status_message = f"Delaying next prediction: {int(elapsed_delay)}/{int(DELAY_BETWEEN_PREDICTIONS_SECONDS)}s"
        # Keep the last result on screen for the whole cool-down.
        prediction_result = _last_prediction_display
        if elapsed_delay >= DELAY_BETWEEN_PREDICTIONS_SECONDS:
            app_state = "recording"
            current_clip_start_time = current_time
            status_message = "Starting new recording..."
            prediction_result = "Ready..."
            print("DEBUG: Transitioning back to 'recording' state.")

    return status_message, prediction_result

def reset_app_state_manual():
    """Abandon the current cycle and start recording again from scratch.

    Empties the frame buffer, restarts the recording and delay timers from the
    current time, and puts the state machine back into ``"recording"``.

    Returns:
        ``(status_message, prediction_result)`` strings for the two UI textboxes.
    """
    global raw_frames_buffer, current_clip_start_time, last_prediction_completion_time, app_state

    # Discard any partially recorded clip.
    raw_frames_buffer.clear()

    # Both timers restart now and the cycle resumes in the recording phase.
    current_clip_start_time = time.time()
    last_prediction_completion_time = time.time()
    app_state = "recording"

    print("DEBUG: Manual reset triggered.")
    ready_status = "Ready to record..."
    ready_prediction = "Ready for new prediction."
    return ready_status, ready_prediction

def _dummy_frames_api(frames_list):
    """Placeholder handler for the API tab; reports how many items were received.

    No model inference happens here. ``frames_list`` is the JSON value sent by
    the client (intended to be a list of base64-encoded image strings). A missing
    value (``None``) counts as zero items instead of raising ``TypeError``.
    """
    frame_count = 0 if frames_list is None else len(frames_list)
    return f"Received {frame_count} frames. This is a dummy response. Integrate predict_from_frames_api here."


with gr.Blocks() as demo:
    gr.Markdown(
        """
        # TimesFormer Crime Detection - Hugging Face Space Host
        This Space hosts the `owinymarvin/timesformer-crime-detection` model.
        Live webcam demo with recording and prediction phases.
        """
    )

    with gr.Tab("Live Webcam Demo"):
        gr.Markdown(
            f"""
            Continuously captures live webcam feed for **{RAW_RECORDING_DURATION_SECONDS} seconds**,
            then makes a prediction. There is a **{DELAY_BETWEEN_PREDICTIONS_SECONDS/60:.0f} minute delay** afterwards.
            """
        )
        with gr.Row():
            with gr.Column():
                webcam_input = gr.Image(
                    sources=["webcam"],
                    streaming=True,
                    label="Live Webcam Feed"
                )
                status_output = gr.Textbox(label="Current Status", value="Initializing...")
                reset_button = gr.Button("Reset / Start New Cycle")
            with gr.Column():
                prediction_output = gr.Textbox(label="Prediction Result", value="Waiting...")

        # Every streamed frame advances the record -> predict -> delay state machine.
        webcam_input.stream(
            live_predict_stream,
            inputs=[webcam_input],
            outputs=[status_output, prediction_output]
        )
        reset_button.click(
            reset_app_state_manual,
            inputs=[],
            outputs=[status_output, prediction_output]
        )

    with gr.Tab("API Endpoint for External Clients"):
        gr.Markdown(
            """
            Use this API endpoint to send base64-encoded frames for prediction.
            """
        )
        # Placeholder endpoint: Gradio's generated API docs show its JSON input
        # and text output, but it does NOT run the model. This module defines no
        # `predict_from_frames_api` function; callers only get the dummy
        # acknowledgement produced by `_dummy_frames_api`.
        # NOTE(review): confirm this nested Interface is rendered inside the tab
        # for the installed Gradio version (some versions require `.render()`).
        gr.Interface(
            fn=_dummy_frames_api,
            inputs=gr.Json(label="List of Base64-encoded image strings"),
            outputs=gr.Textbox(label="API Response"),
            live=False,
            allow_flagging="never"  # For API endpoints, flagging is usually not desired
        )


if __name__ == "__main__":
    demo.launch()