import gradio as gr
import torch
from transformers import AutoImageProcessor, TimesformerForVideoClassification
import cv2
from PIL import Image
import numpy as np
import time
from collections import deque
import base64
import io

# --- Configuration ---
HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"
MODEL_INPUT_NUM_FRAMES = 8  # number of frames per clip passed to the model
TARGET_IMAGE_HEIGHT = 224
TARGET_IMAGE_WIDTH = 224
RAW_RECORDING_DURATION_SECONDS = 10.0  # length of each recorded clip
FRAMES_TO_SAMPLE_PER_CLIP = 20  # intermediate down-sample before MODEL_INPUT_NUM_FRAMES
DELAY_BETWEEN_PREDICTIONS_SECONDS = 120.0  # 2 minutes for CPU

# --- Load Model and Processor ---
print(f"Loading model and processor from {HF_MODEL_REPO_ID}...")
try:
    processor = AutoImageProcessor.from_pretrained(HF_MODEL_REPO_ID)
    model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID)
except Exception as e:
    print(f"Error loading model: {e}")
    # Exit with a non-zero status so the failure is visible to the host;
    # a bare `exit()` reports success (status 0).
    raise SystemExit(1) from e

model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Model loaded on {device}.")

# --- Global State Variables for Live Demo ---
# Module-level state shared by the webcam stream handler and the reset button.
raw_frames_buffer = deque()  # PIL frames captured during the current clip
current_clip_start_time = time.time()  # reset when the clip's first frame arrives
last_prediction_completion_time = time.time()  # start of the current cool-down
app_state = "recording"  # States: "recording", "predicting", "processing_delay"


# --- Helper function to sample frames ---
def sample_frames(frames_list, target_count):
    """Return up to ``target_count`` frames evenly spaced across ``frames_list``.

    When the list already holds ``target_count`` frames or fewer it is returned
    as-is (no padding), so callers must check the resulting length. When
    down-sampling, the first and last frames are always included.
    """
    if not frames_list:
        return []
    if len(frames_list) <= target_count:
        return frames_list
    indices = np.linspace(0, len(frames_list) - 1, target_count, dtype=int)
    return [frames_list[i] for i in indices.tolist()]


def _start_new_recording(now=None):
    """Discard buffered frames and put the state machine back into "recording"."""
    global current_clip_start_time, last_prediction_completion_time, app_state
    if now is None:
        now = time.time()
    raw_frames_buffer.clear()
    current_clip_start_time = now
    last_prediction_completion_time = now
    app_state = "recording"


def _classify_clip(frames):
    """Run the model on a list of PIL frames; return ``(label, confidence)``.

    ``confidence`` is the softmax probability of the top-scoring class.
    """
    processed_input = processor(images=frames, return_tensors="pt")
    pixel_values = processed_input.pixel_values.to(device)
    with torch.no_grad():
        logits = model(pixel_values).logits
    predicted_class_id = logits.argmax(-1).item()
    predicted_label = model.config.id2label.get(predicted_class_id, "Unknown")
    confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()
    return predicted_label, confidence


# --- Main processing function for Live Demo Stream ---
def live_predict_stream(image_np_array):
    """Advance the record -> predict -> cool-down cycle by one webcam frame.

    Generator wired to ``webcam_input.stream``; every yielded tuple is
    ``(status_text, prediction_text)`` for the two output textboxes.

    Global state is always updated *before* yielding, so it stays consistent
    even if the generator is not resumed after a yield.
    """
    global current_clip_start_time, last_prediction_completion_time, app_state

    if image_np_array is None:
        # Image.fromarray(None) would raise; leave the prediction box untouched.
        yield "Waiting for webcam frames...", gr.update()
        return

    current_time = time.time()

    if app_state == "recording":
        if not raw_frames_buffer:
            # Start the clip clock on the clip's first frame. Otherwise the time
            # since server start (or since the last reset) counts as recording
            # time and the clip ends after a single frame.
            current_clip_start_time = current_time
        raw_frames_buffer.append(Image.fromarray(image_np_array))
        elapsed_recording_time = current_time - current_clip_start_time
        yield (
            f"Recording: {elapsed_recording_time:.1f}/{RAW_RECORDING_DURATION_SECONDS}s. "
            f"Raw frames: {len(raw_frames_buffer)}",
            "Buffering...",
        )
        if elapsed_recording_time >= RAW_RECORDING_DURATION_SECONDS:
            # Inference runs on the next incoming frame.
            app_state = "predicting"
            print("DEBUG: Transitioning to 'predicting' state.")
            yield "Preparing to predict...", "Processing..."

    elif app_state == "predicting":
        if not raw_frames_buffer:
            # Defensive: without this the handler would stay in "predicting"
            # forever while yielding nothing.
            _start_new_recording(current_time)
            yield "Starting new recording...", "Ready for new prediction."
            return

        print("DEBUG: Starting prediction.")
        sampled_raw_frames = sample_frames(list(raw_frames_buffer), FRAMES_TO_SAMPLE_PER_CLIP)
        frames_for_model = sample_frames(sampled_raw_frames, MODEL_INPUT_NUM_FRAMES)

        if len(frames_for_model) < MODEL_INPUT_NUM_FRAMES:
            frame_count = len(frames_for_model)
            print(f"ERROR: Insufficient frames for model input: {frame_count}/{MODEL_INPUT_NUM_FRAMES}. Resetting state.")
            _start_new_recording()
            yield "Error during frame sampling.", f"Error: Not enough frames ({frame_count}/{MODEL_INPUT_NUM_FRAMES}). Resetting."
            return

        try:
            predicted_label, confidence = _classify_clip(frames_for_model)
        except Exception as e:
            error_message = f"Error during prediction: {e}"
            print(f"ERROR during prediction: {e}")
            # Drop the problematic frames and start a real cool-down (the
            # timestamp must be reset, or the delay may be skipped entirely).
            raw_frames_buffer.clear()
            last_prediction_completion_time = time.time()
            app_state = "processing_delay"
            yield "Prediction error.", error_message
            return

        prediction_result = f"Predicted: {predicted_label} (Confidence: {confidence:.2f})"
        print(f"DEBUG: Prediction Result: {prediction_result}")
        raw_frames_buffer.clear()
        # Time the cool-down from after inference, which can be slow on CPU.
        last_prediction_completion_time = time.time()
        app_state = "processing_delay"
        print("DEBUG: Transitioning to 'processing_delay' state.")
        yield "Prediction complete.", prediction_result

    elif app_state == "processing_delay":
        elapsed_delay = current_time - last_prediction_completion_time
        if elapsed_delay < DELAY_BETWEEN_PREDICTIONS_SECONDS:
            # gr.update() with no arguments leaves the prediction textbox
            # unchanged, so the last result stays visible.
            yield (
                f"Delaying next prediction: {int(elapsed_delay)}/{int(DELAY_BETWEEN_PREDICTIONS_SECONDS)}s",
                gr.update(),
            )
        else:
            _start_new_recording(current_time)
            print("DEBUG: Transitioning back to 'recording' state.")
            yield "Starting new recording...", "Ready for new prediction."


def reset_app_state_manual():
    """Handle the reset button: discard buffered frames and restart recording.

    Returns the initial ``(status_text, prediction_text)`` for the textboxes.
    """
    _start_new_recording()
    print("DEBUG: Manual reset triggered.")
    return "Ready to record...", "Ready for new prediction."


# --- Gradio UI Layout ---
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # TimesFormer Crime Detection - Hugging Face Space Host
        This Space hosts the `owinymarvin/timesformer-crime-detection` model.
        Live webcam demo with recording and prediction phases.
        """
    )

    with gr.Tab("Live Webcam Demo"):
        gr.Markdown(
            f"""
            Continuously captures live webcam feed for **{RAW_RECORDING_DURATION_SECONDS} seconds**, then makes a prediction.
            There is a **{DELAY_BETWEEN_PREDICTIONS_SECONDS/60:.0f} minute delay** afterwards.
            """
        )
        with gr.Row():
            with gr.Column():
                webcam_input = gr.Image(
                    sources=["webcam"],
                    streaming=True,
                    label="Live Webcam Feed"
                )
                status_output = gr.Textbox(label="Current Status", value="Initializing...")
                reset_button = gr.Button("Reset / Start New Cycle")
            with gr.Column():
                prediction_output = gr.Textbox(label="Prediction Result", value="Waiting...")

        # webcam_input.stream() with a generator handler lets each `yield`
        # update the UI progressively.
        webcam_input.stream(
            live_predict_stream,
            inputs=[webcam_input],
            outputs=[status_output, prediction_output]
        )

        # The reset button is a regular click event, not a stream.
        reset_button.click(
            reset_app_state_manual,
            inputs=[],
            outputs=[status_output, prediction_output]
        )

    with gr.Tab("API Endpoint for External Clients"):
        gr.Markdown(
            """
            Use this API endpoint to send base64-encoded frames for prediction.
            """
        )
        # Placeholder: the lambda below only returns a fixed message and
        # performs no prediction.
        # NOTE(review): an earlier comment said API calls target
        # /run/predict_from_frames_api, but no such handler is defined in this
        # file — confirm where (or whether) that endpoint exists.
        # NOTE(review): an Interface created inside a Blocks context may need an
        # explicit .render() to appear — confirm for the installed Gradio version.
        gr.Interface(
            fn=lambda frames_list: "API endpoint is active for programmatic calls. See documentation in app.py.",
            inputs=gr.Json(label="List of Base64-encoded image strings"),
            outputs=gr.Textbox(label="API Response"),
            live=False,
            allow_flagging="never"
        )

if __name__ == "__main__":
    demo.launch()