import gradio as gr
import torch
from transformers import AutoImageProcessor, TimesformerForVideoClassification
from PIL import Image
import numpy as np
import time
from collections import deque
# --- Configuration ---
# CHANGED: Using a public Facebook TimeSformer model fine-tuned on Kinetics
HF_MODEL_REPO_ID = "facebook/timesformer-base-finetuned-kinetics"
MODEL_INPUT_NUM_FRAMES = 8
TARGET_IMAGE_HEIGHT = 224
TARGET_IMAGE_WIDTH = 224
RAW_RECORDING_DURATION_SECONDS = 10.0
FRAMES_TO_SAMPLE_PER_CLIP = 20
DELAY_BETWEEN_PREDICTIONS_SECONDS = 120.0  # 2 minutes for CPU, adjust for GPU
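
# Sampling pipeline sketch (webcam frame rates vary, so counts are approximate):
# buffer ~10 s of raw frames -> uniformly sample 20 of them -> uniformly
# re-sample 8 of those as the model input. The two-stage sampling keeps memory
# bounded while spreading the final 8 frames evenly across the whole clip.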
# --- Load Model and Processor ---
print(f"Loading model and processor from {HF_MODEL_REPO_ID}...")
try:
    processor = AutoImageProcessor.from_pretrained(HF_MODEL_REPO_ID)
    model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID)
except Exception as e:
    raise SystemExit(f"Error loading model: {e}")

model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Model loaded on {device}.")
print(f"Model's class labels (Kinetics): {model.config.id2label}")
# --- Global State Variables for Live Demo ---
raw_frames_buffer = deque()
current_clip_start_time = time.time()
last_prediction_completion_time = time.time()
app_state = "recording"  # States: "recording", "predicting", "processing_delay"
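# Cycle: "recording" -> "predicting" -> "processing_delay" -> back to "recording"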
# --- Helper function to sample frames ---
def sample_frames(frames_list, target_count):
    if not frames_list:
        return []
    if len(frames_list) <= target_count:
        return frames_list
    indices = np.linspace(0, len(frames_list) - 1, target_count, dtype=int)
    return [frames_list[int(i)] for i in indices]
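
# Example: sample_frames(list(range(10)), 4) returns the frames at indices
# [0, 3, 6, 9]; np.linspace spreads the picks evenly across the clip.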
# --- Main processing function for Live Demo Stream ---
def live_predict_stream(image_np_array):
    global raw_frames_buffer, current_clip_start_time, last_prediction_completion_time, app_state

    current_time = time.time()
    pil_image = Image.fromarray(image_np_array)

    if app_state == "recording":
        raw_frames_buffer.append(pil_image)
        elapsed_recording_time = current_time - current_clip_start_time
        yield (
            f"Recording: {elapsed_recording_time:.1f}/{RAW_RECORDING_DURATION_SECONDS}s. "
            f"Raw frames: {len(raw_frames_buffer)}",
            "Buffering...",
        )
        if elapsed_recording_time >= RAW_RECORDING_DURATION_SECONDS:
            # Transition to predicting state
            app_state = "predicting"
            yield "Preparing to predict...", "Processing..."
            print("DEBUG: Transitioning to 'predicting' state.")

    elif app_state == "predicting":
        # The buffer check ensures this prediction block runs only once per cycle
        if raw_frames_buffer:
            print("DEBUG: Starting prediction.")
            try:
                sampled_raw_frames = sample_frames(list(raw_frames_buffer), FRAMES_TO_SAMPLE_PER_CLIP)
                frames_for_model = sample_frames(sampled_raw_frames, MODEL_INPUT_NUM_FRAMES)
                if len(frames_for_model) < MODEL_INPUT_NUM_FRAMES:
                    yield (
                        "Error during frame sampling.",
                        f"Error: Not enough frames ({len(frames_for_model)}/{MODEL_INPUT_NUM_FRAMES}). Resetting.",
                    )
                    print(f"ERROR: Insufficient frames for model input: {len(frames_for_model)}/{MODEL_INPUT_NUM_FRAMES}. Resetting state.")
                    app_state = "recording"  # Reset state to start a new recording
                    raw_frames_buffer.clear()
                    current_clip_start_time = time.time()
                    last_prediction_completion_time = time.time()
                    return  # Exit this stream call and wait for the next frame

                processed_input = processor(images=frames_for_model, return_tensors="pt")
                pixel_values = processed_input.pixel_values.to(device)

                with torch.no_grad():
                    outputs = model(pixel_values)
                    logits = outputs.logits

                predicted_class_id = logits.argmax(-1).item()
                predicted_label = model.config.id2label.get(predicted_class_id, "Unknown")
                confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()

                prediction_result = f"Predicted: {predicted_label} (Confidence: {confidence:.2f})"
                status_message = "Prediction complete."
                print(f"DEBUG: Prediction Result: {prediction_result}")

                # Yield the prediction result immediately so the UI updates
                yield status_message, prediction_result

                # Clear the buffer and transition to the delay AFTER yielding the prediction
                raw_frames_buffer.clear()
                last_prediction_completion_time = current_time
                app_state = "processing_delay"
                print("DEBUG: Transitioning to 'processing_delay' state.")
            except Exception as e:
                error_message = f"Error during prediction: {e}"
                print(f"ERROR during prediction: {e}")
                # Yield error to UI
                yield "Prediction error.", error_message
                # Still go to the delay state to avoid a tight error loop
                app_state = "processing_delay"
                # Clear the buffer so the same problematic frames are not re-processed
                raw_frames_buffer.clear()

    elif app_state == "processing_delay":
        elapsed_delay = current_time - last_prediction_completion_time
        if elapsed_delay < DELAY_BETWEEN_PREDICTIONS_SECONDS:
            # Keep the last prediction visible while counting down the delay
            yield (
                f"Delaying next prediction: {int(elapsed_delay)}/{int(DELAY_BETWEEN_PREDICTIONS_SECONDS)}s",
                gr.skip(),
            )
        else:
            # Delay is over; reset for a new recording cycle
            app_state = "recording"
            current_clip_start_time = current_time
            print("DEBUG: Transitioning back to 'recording' state.")
            yield "Starting new recording...", "Ready for new prediction."
def reset_app_state_manual():
    global raw_frames_buffer, current_clip_start_time, last_prediction_completion_time, app_state
    raw_frames_buffer.clear()
    current_clip_start_time = time.time()
    last_prediction_completion_time = time.time()
    app_state = "recording"
    print("DEBUG: Manual reset triggered.")
    # Return initial values immediately upon reset
    return "Ready to record...", "Ready for new prediction."
# --- Gradio UI Layout ---
with gr.Blocks() as demo:
    gr.Markdown(
        f"""
        # TimeSformer Action Recognition - Using Facebook Kinetics Model
        This Space hosts the `{HF_MODEL_REPO_ID}` model as a live webcam demo
        with alternating recording and prediction phases.
        **NOTE: This model predicts general human actions (e.g., 'playing guitar', 'walking'), not crime events.**
        """
    )
with gr.Tab("Live Webcam Demo"): | |
gr.Markdown( | |
f""" | |
Continuously captures live webcam feed for **{RAW_RECORDING_DURATION_SECONDS} seconds**, | |
then makes a prediction. There is a **{DELAY_BETWEEN_PREDICTIONS_SECONDS/60:.0f} minute delay** afterwards. | |
""" | |
) | |
with gr.Row(): | |
with gr.Column(): | |
webcam_input = gr.Image( | |
sources=["webcam"], | |
streaming=True, | |
label="Live Webcam Feed" | |
) | |
status_output = gr.Textbox(label="Current Status", value="Initializing...") | |
reset_button = gr.Button("Reset / Start New Cycle") | |
with gr.Column(): | |
prediction_output = gr.Textbox(label="Prediction Result", value="Waiting...") | |
webcam_input.stream( | |
live_predict_stream, | |
inputs=[webcam_input], | |
outputs=[status_output, prediction_output] | |
) | |
reset_button.click( | |
reset_app_state_manual, | |
inputs=[], | |
outputs=[status_output, prediction_output] | |
) | |
with gr.Tab("API Endpoint for External Clients"): | |
gr.Markdown( | |
""" | |
Use this API endpoint to send base64-encoded frames for prediction. | |
(Currently uses the Kinetics model). | |
""" | |
) | |
gr.Interface( | |
fn=lambda frames_list: "API endpoint is active for programmatic calls. See documentation in app.py.", | |
inputs=gr.Json(label="List of Base64-encoded image strings"), | |
outputs=gr.Textbox(label="API Response"), | |
live=False, | |
allow_flagging="never" | |
) | |
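
        # --- Example external client (hypothetical sketch) ---
        # A minimal sketch of how an outside script might call the endpoint
        # above, assuming this Space is published under the placeholder ID
        # "your-username/your-space" and that Gradio exposes the Interface at
        # the default "/predict" route. Requires `pip install gradio_client`;
        # run it from a separate process, not inside this app:
        #
        #   from gradio_client import Client
        #
        #   client = Client("your-username/your-space")  # placeholder Space ID
        #   response = client.predict(
        #       ["<base64 frame 1>", "<base64 frame 2>"],  # gr.Json payload
        #       api_name="/predict",  # assumed default endpoint name
        #   )
        #   print(response)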
if __name__ == "__main__":
    demo.launch()