owinymarvin's picture
latest changes
89ce7bf
raw
history blame
8.79 kB
import gradio as gr
import torch
from transformers import AutoImageProcessor, TimesformerForVideoClassification
import cv2
from PIL import Image
import numpy as np
import time
from collections import deque
import base64
import io
# --- Configuration ---
# Model checkpoint: public Facebook TimesFormer fine-tuned on Kinetics.
HF_MODEL_REPO_ID = "facebook/timesformer-base-finetuned-kinetics"

# Frames per clip handed to the model (after sampling).
MODEL_INPUT_NUM_FRAMES = 8

# Frame size constants; not referenced by the rest of the visible code.
TARGET_IMAGE_HEIGHT = TARGET_IMAGE_WIDTH = 224

# Live-demo timing (seconds) and the first-stage sampling size.
RAW_RECORDING_DURATION_SECONDS = 10.0  # length of each recorded clip
FRAMES_TO_SAMPLE_PER_CLIP = 20  # raw buffer is thinned to this many frames first
DELAY_BETWEEN_PREDICTIONS_SECONDS = 2 * 60.0  # 2 minutes suits CPU; lower it on a GPU
# --- Load Model and Processor ---
print(f"Loading model and processor from {HF_MODEL_REPO_ID}...")
try:
    processor = AutoImageProcessor.from_pretrained(HF_MODEL_REPO_ID)
    model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID)
except Exception as e:
    print(f"Error loading model: {e}")
    # The app cannot run without the model. Exit with a non-zero status so the
    # failure is visible to whatever launched the script. The builtin `exit()`
    # is a `site` helper meant for the interactive shell, and called with no
    # argument it reports success (status 0).
    raise SystemExit(1) from e
model.eval()  # inference only
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Model loaded on {device}.")
print(f"Model's class labels (Kinetics): {model.config.id2label}")

# --- Global State Variables for Live Demo ---
# Mutable module-level state shared by live_predict_stream and
# reset_app_state_manual (both declare these names `global`).
raw_frames_buffer = deque()  # PIL frames captured during the current recording
current_clip_start_time = time.time()  # when the current recording started
last_prediction_completion_time = time.time()  # start of the post-prediction delay
app_state = "recording"  # one of "recording", "predicting", "processing_delay"
# --- Helper function to sample frames ---
def sample_frames(frames_list, target_count):
    """Pick up to ``target_count`` evenly spaced items from ``frames_list``.

    Returns a new empty list when ``frames_list`` is empty, and the input list
    itself (not a copy) when it already holds ``target_count`` items or fewer.
    Otherwise the picked positions come from ``np.linspace`` over the full
    index range with an integer dtype. When ``target_count >= 2``, the first and
    last items are therefore always included.
    """
    if not frames_list:
        return []
    total = len(frames_list)
    if total <= target_count:
        return frames_list
    positions = np.linspace(0, total - 1, target_count, dtype=int)
    return [frames_list[int(pos)] for pos in positions]
# --- Main processing function for Live Demo Stream ---
def _classify_clip(frames):
    """Run the model once on ``frames`` and return ``(label, confidence)``.

    ``frames`` is passed to ``processor`` as the ``images`` of one clip. The
    ``.item()`` on the argmax below requires exactly one clip in the batch.
    ``confidence`` is the softmax probability of the predicted class, and
    ``label`` falls back to "Unknown" for ids missing from ``id2label``.
    """
    processed_input = processor(images=frames, return_tensors="pt")
    pixel_values = processed_input.pixel_values.to(device)
    with torch.no_grad():  # inference only; no autograd graph needed
        logits = model(pixel_values).logits
    predicted_class_id = logits.argmax(-1).item()
    predicted_label = model.config.id2label.get(predicted_class_id, "Unknown")
    confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()
    return predicted_label, confidence


def live_predict_stream(image_np_array):
    """Advance the live demo's record -> predict -> delay cycle by one frame.

    Registered as the ``stream`` handler of the webcam image, so it runs for
    every streamed frame. Each yielded value is a ``(status_text,
    prediction_text)`` pair for the status and prediction textboxes.

    The current phase lives in the module-level ``app_state``:

    * ``"recording"``: buffer frames until RAW_RECORDING_DURATION_SECONDS
      have elapsed, then switch to ``"predicting"``.
    * ``"predicting"``: sample MODEL_INPUT_NUM_FRAMES frames from the buffer,
      classify them once and switch to ``"processing_delay"``. The frame that
      triggers this call is not buffered.
    * ``"processing_delay"``: wait DELAY_BETWEEN_PREDICTIONS_SECONDS after the
      last prediction (or prediction error), then return to ``"recording"``.

    Args:
        image_np_array: The current webcam frame as an array accepted by
            ``PIL.Image.fromarray``. ``None`` is tolerated and ignored.
    """
    global raw_frames_buffer, current_clip_start_time, last_prediction_completion_time, app_state
    if image_np_array is None:
        # NOTE(review): defensive guard in case the stream delivers no frame;
        # leave both textboxes unchanged instead of failing in Image.fromarray.
        yield gr.update(), gr.update()
        return
    current_time = time.time()

    if app_state == "recording":
        # Only the recording phase keeps the frame, so convert it here.
        raw_frames_buffer.append(Image.fromarray(image_np_array))
        elapsed_recording_time = current_time - current_clip_start_time
        yield f"Recording: {elapsed_recording_time:.1f}/{RAW_RECORDING_DURATION_SECONDS}s. Raw frames: {len(raw_frames_buffer)}", "Buffering..."
        if elapsed_recording_time >= RAW_RECORDING_DURATION_SECONDS:
            # The prediction itself runs when the next frame arrives.
            app_state = "predicting"
            yield "Preparing to predict...", "Processing..."
            print("DEBUG: Transitioning to 'predicting' state.")

    elif app_state == "predicting":
        if not raw_frames_buffer:
            # Nothing to classify: start a fresh recording rather than staying
            # in "predicting" (and yielding nothing) forever.
            app_state = "recording"
            current_clip_start_time = current_time
            print("DEBUG: Empty buffer in 'predicting' state; restarting recording.")
            yield "Starting new recording...", "Ready for new prediction."
            return

        print("DEBUG: Starting prediction.")
        try:
            # Two-stage sampling: thin the raw buffer first, then pick the
            # model's frame count from that subsample.
            sampled_raw_frames = sample_frames(list(raw_frames_buffer), FRAMES_TO_SAMPLE_PER_CLIP)
            frames_for_model = sample_frames(sampled_raw_frames, MODEL_INPUT_NUM_FRAMES)
            if len(frames_for_model) < MODEL_INPUT_NUM_FRAMES:
                # Clip too short for the model: discard it and record again.
                print(f"ERROR: Insufficient frames for model input: {len(frames_for_model)}/{MODEL_INPUT_NUM_FRAMES}. Resetting state.")
                raw_frames_buffer.clear()
                current_clip_start_time = time.time()
                last_prediction_completion_time = time.time()
                app_state = "recording"
                yield "Error during frame sampling.", f"Error: Not enough frames ({len(frames_for_model)}/{MODEL_INPUT_NUM_FRAMES}). Resetting."
                return
            predicted_label, confidence = _classify_clip(frames_for_model)
        except Exception as e:
            print(f"ERROR during prediction: {e}")
            # Still enter the delay, and restart its timer now, so a persistent
            # failure is not retried on every frame. Without resetting the
            # timer the delay could already be over. Drop the problematic frames.
            raw_frames_buffer.clear()
            last_prediction_completion_time = time.time()
            app_state = "processing_delay"
            yield "Prediction error.", f"Error during prediction: {e}"
            return

        prediction_result = f"Predicted: {predicted_label} (Confidence: {confidence:.2f})"
        print(f"DEBUG: Prediction Result: {prediction_result}")
        # Update state before yielding so an abandoned generator cannot leave
        # the old buffer behind to be classified again. The delay is measured
        # from when inference actually finished, not from when this call began.
        raw_frames_buffer.clear()
        last_prediction_completion_time = time.time()
        app_state = "processing_delay"
        print("DEBUG: Transitioning to 'processing_delay' state.")
        yield "Prediction complete.", prediction_result

    elif app_state == "processing_delay":
        elapsed_delay = current_time - last_prediction_completion_time
        if elapsed_delay < DELAY_BETWEEN_PREDICTIONS_SECONDS:
            # gr.update() with no arguments leaves the prediction textbox
            # unchanged, so the last result stays visible during the delay.
            yield f"Delaying next prediction: {int(elapsed_delay)}/{int(DELAY_BETWEEN_PREDICTIONS_SECONDS)}s", gr.update()
        else:
            app_state = "recording"
            current_clip_start_time = current_time
            print("DEBUG: Transitioning back to 'recording' state.")
            yield "Starting new recording...", "Ready for new prediction."
def reset_app_state_manual():
    """Handle the "Reset / Start New Cycle" button.

    Puts the live demo back into the "recording" phase, drops any buffered
    frames and restarts both timers. Returns the ``(status_text,
    prediction_text)`` pair shown in the status and prediction textboxes.
    """
    global raw_frames_buffer, current_clip_start_time, last_prediction_completion_time, app_state
    app_state = "recording"
    raw_frames_buffer.clear()
    current_clip_start_time = time.time()
    last_prediction_completion_time = time.time()
    print("DEBUG: Manual reset triggered.")
    status_text = "Ready to record..."
    prediction_text = "Ready for new prediction."
    return status_text, prediction_text
# --- Gradio UI Layout ---
# Markdown texts are built up front so the layout below stays readable.
_HEADER_MD = f"""
# TimesFormer Action Recognition - Using Facebook Kinetics Model
This Space hosts the `{HF_MODEL_REPO_ID}` model.
Live webcam demo with recording and prediction phases.
**NOTE: This model predicts general human actions (e.g., 'playing guitar', 'walking'), not crime events.**
"""

# Describe the delay in whole minutes when that is exact, otherwise in seconds,
# so a value that is not a multiple of 60 (e.g. 30 s on a GPU) is not rounded
# to something like "0 minute delay".
_DELAY_MINUTES = DELAY_BETWEEN_PREDICTIONS_SECONDS / 60
_DELAY_TEXT = (
    f"{_DELAY_MINUTES:.0f} minute"
    if float(_DELAY_MINUTES).is_integer()
    else f"{DELAY_BETWEEN_PREDICTIONS_SECONDS:g} second"
)

_LIVE_TAB_MD = f"""
Continuously captures live webcam feed for **{RAW_RECORDING_DURATION_SECONDS} seconds**,
then makes a prediction. There is a **{_DELAY_TEXT} delay** afterwards.
"""

_API_TAB_MD = """
Use this API endpoint to send base64-encoded frames for prediction.
(Currently uses the Kinetics model).
"""

with gr.Blocks() as demo:
    gr.Markdown(_HEADER_MD)

    with gr.Tab("Live Webcam Demo"):
        gr.Markdown(_LIVE_TAB_MD)
        with gr.Row():
            with gr.Column():
                webcam_input = gr.Image(
                    sources=["webcam"],
                    streaming=True,
                    label="Live Webcam Feed",
                )
                status_output = gr.Textbox(label="Current Status", value="Initializing...")
                reset_button = gr.Button("Reset / Start New Cycle")
            with gr.Column():
                prediction_output = gr.Textbox(label="Prediction Result", value="Waiting...")
        # Every streamed webcam frame advances the record/predict/delay cycle.
        webcam_input.stream(
            live_predict_stream,
            inputs=[webcam_input],
            outputs=[status_output, prediction_output],
        )
        reset_button.click(
            reset_app_state_manual,
            inputs=[],
            outputs=[status_output, prediction_output],
        )

    with gr.Tab("API Endpoint for External Clients"):
        gr.Markdown(_API_TAB_MD)
        # NOTE(review): this endpoint is a stub. The lambda ignores frames_list
        # and never calls the model; it always returns the fixed message below.
        gr.Interface(
            fn=lambda frames_list: "API endpoint is active for programmatic calls. See documentation in app.py.",
            inputs=gr.Json(label="List of Base64-encoded image strings"),
            outputs=gr.Textbox(label="API Response"),
            live=False,
            allow_flagging="never",
        )

if __name__ == "__main__":
    demo.launch()