SW_AI_deployment / app copy.py
owinymarvin's picture
latest changes
9d0ee1c
raw
history blame
4.08 kB
import gradio as gr
import torch
from transformers import AutoImageProcessor, TimesformerForVideoClassification
import cv2
from PIL import Image
import numpy as np
import time
from collections import deque
# --- Configuration ---
# Hugging Face Hub repository ID the model and image processor are loaded from.
HF_MODEL_REPO_ID: str = "owinymarvin/timesformer-crime-detection"

# Clip settings; these have to agree with the values used during training.
NUM_FRAMES: int = 8
TARGET_IMAGE_WIDTH: int = 224
TARGET_IMAGE_HEIGHT: int = 224
# --- Load Model and Processor ---
# Executed once at import time; the resulting `processor`, `model` and `device`
# are module-level names used by the prediction callback.
print(f"Loading model and image processor from {HF_MODEL_REPO_ID}...")
try:
    processor = AutoImageProcessor.from_pretrained(HF_MODEL_REPO_ID)
    model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID)
except Exception as e:
    print(f"Error loading model from Hugging Face Hub: {e}")
    # Stop with a non-zero exit status so the failure is visible to whatever
    # launched the app. The bare `exit()` helper is injected by the `site`
    # module (not guaranteed to exist) and, called without an argument,
    # terminates with status 0, i.e. it reports success.
    raise SystemExit(1) from e

model.eval()  # Set model to evaluation mode

# Prefer the GPU when one is available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Model loaded successfully on {device}.")
print(f"Model's class labels: {model.config.id2label}")
# --- Shared streaming state (module-level, mutated by the prediction callback) ---
# Text shown until the first inference has produced a result.
current_prediction_text: str = "Buffering frames..."

# Minimum number of seconds between two inference runs.
inference_interval: float = 1.0

# Timestamp of the most recent inference; starts at import time.
last_inference_time: float = time.time()

# Bounded deque: once NUM_FRAMES frames are stored, appending a new frame
# discards the oldest one.
frame_buffer: deque = deque(maxlen=NUM_FRAMES)
def predict_video_frame(image_np_array):
    """Buffer one webcam frame and return the latest prediction text.

    Each frame is appended to the module-level ``frame_buffer``. When the
    buffer holds ``NUM_FRAMES`` frames and at least ``inference_interval``
    seconds have elapsed since the previous inference, the buffered frames
    are classified and ``current_prediction_text`` is replaced.

    Args:
        image_np_array: One frame from the Gradio image input, in a form
            accepted by ``PIL.Image.fromarray``, or ``None`` when no frame
            is available.

    Returns:
        The current prediction string. On calls that do not run inference
        (missing frame, buffer not full, interval not elapsed) this is the
        previously stored text.
    """
    # NOTE(review): buffer, timer and text are module-level, so every caller
    # of this function shares them; presumably that includes concurrent
    # browser sessions -- confirm whether per-session state is needed.
    global last_inference_time, current_prediction_text

    # The callback may be invoked without a frame (e.g. when the image
    # component is empty); Image.fromarray(None) would raise, so keep the
    # last text instead of failing the stream.
    if image_np_array is None:
        return current_prediction_text

    # The processor handles resizing and normalisation, so the raw frame is
    # stored as-is.
    frame_buffer.append(Image.fromarray(image_np_array))

    now = time.time()
    buffer_full = len(frame_buffer) == NUM_FRAMES
    interval_elapsed = (now - last_inference_time) >= inference_interval
    if not (buffer_full and interval_elapsed):
        return current_prediction_text

    last_inference_time = now

    # The processor accepts a list of PIL images; snapshot the deque.
    processed_input = processor(images=list(frame_buffer), return_tensors="pt")
    pixel_values = processed_input.pixel_values.to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(pixel_values).logits

    predicted_class_id = logits.argmax(-1).item()
    predicted_label = model.config.id2label[predicted_class_id]
    confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()

    current_prediction_text = f"Predicted: {predicted_label} ({confidence:.2f})"
    print(current_prediction_text)  # Print to Space logs
    return current_prediction_text
# --- Gradio Interface ---
# Webcam component that streams frames to the prediction callback.
# `shape=` is intentionally not passed here: supplying it caused a TypeError.
webcam_input = gr.Image(
    label="Live Webcam Feed",
    sources=["webcam"],
    streaming=True,
)

# Text box that displays whatever the callback returns.
prediction_output = gr.Textbox(value="Buffering frames...", label="Real-time Prediction")

# Wire the webcam input, the callback and the text output together.
demo = gr.Interface(
    title="TimesFormer Crime Detection Live Demo",
    description=(
        f"This demo uses a finetuned TimesFormer model ({HF_MODEL_REPO_ID}) to predict "
        f"crime actions from a live webcam feed. The model processes {NUM_FRAMES} frames "
        f"at a time and makes a prediction every {inference_interval} seconds. "
        "Please allow webcam access."
    ),
    fn=predict_video_frame,
    inputs=webcam_input,
    outputs=prediction_output,
    live=True,  # Enable live updates
    allow_flagging="never",  # Disable flagging on public demo
)

# Start the app only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()