import gradio as gr
import torch
import cv2
import numpy as np
import json
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download
# --- 1. Define Model Architecture (copied from small_video_classifier.py) ---
class SmallVideoClassifier(torch.nn.Module):
    def __init__(self, num_classes=2, num_frames=8):
        # num_frames is accepted for config compatibility; the adaptive
        # temporal pooling below works for any clip length.
super(SmallVideoClassifier, self).__init__()
from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
try:
weights = MobileNet_V3_Small_Weights.IMAGENET1K_V1
except Exception:
print("Warning: MobileNet_V3_Small_Weights.IMAGENET1K_V1 not found, initializing without pre-trained weights.")
weights = None
self.feature_extractor = mobilenet_v3_small(weights=weights)
self.feature_extractor.classifier = torch.nn.Identity()
        self.num_spatial_features = 576  # feature dim of mobilenet_v3_small (input size of its removed classifier)
self.temporal_aggregator = torch.nn.AdaptiveAvgPool1d(1)
self.classifier = torch.nn.Sequential(
torch.nn.Linear(self.num_spatial_features, 512),
torch.nn.ReLU(),
torch.nn.Dropout(0.2),
torch.nn.Linear(512, num_classes)
)
def forward(self, pixel_values):
batch_size, num_frames, channels, height, width = pixel_values.shape
x = pixel_values.view(batch_size * num_frames, channels, height, width)
spatial_features = self.feature_extractor(x)
spatial_features = spatial_features.view(batch_size, num_frames, self.num_spatial_features)
        # Pool over time: (B, T, F) -> permute to (B, F, T) -> pool -> (B, F)
        temporal_features = self.temporal_aggregator(spatial_features.permute(0, 2, 1)).squeeze(-1)
logits = self.classifier(temporal_features)
return logits
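
# Shape reference (added for clarity; values assume the defaults above):
#   clip = torch.randn(1, 8, 3, 224, 224)      # (batch, frames, C, H, W)
#   SmallVideoClassifier(num_classes=2)(clip)  # -> logits of shape (1, 2)
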
# --- 2. Configuration and Model Loading ---
HF_USERNAME = "owinymarvin"
NEW_MODEL_REPO_ID_SHORT = "timesformer-violence-detector"
NEW_MODEL_REPO_ID = f"{HF_USERNAME}/{NEW_MODEL_REPO_ID_SHORT}"
print(f"Downloading config.json from {NEW_MODEL_REPO_ID}...")
config_path = hf_hub_download(repo_id=NEW_MODEL_REPO_ID, filename="config.json")
with open(config_path, 'r') as f:
model_config = json.load(f)
NUM_FRAMES = model_config.get('num_frames', 8)
IMAGE_SIZE = tuple(model_config.get('image_size', [224, 224]))
NUM_CLASSES = model_config.get('num_classes', 2)
CLASS_LABELS = ["Non-violence", "Violence"]
if NUM_CLASSES != len(CLASS_LABELS):
print(f"Warning: NUM_CLASSES in config ({NUM_CLASSES}) does not match hardcoded CLASS_LABELS length ({len(CLASS_LABELS)}). Adjust CLASS_LABELS if needed.")
device = torch.device("cpu") # Explicitly use CPU
print(f"Using device: {device}")
model = SmallVideoClassifier(num_classes=NUM_CLASSES, num_frames=NUM_FRAMES)
print(f"Downloading model weights from {NEW_MODEL_REPO_ID}...")
model_weights_path = hf_hub_download(repo_id=NEW_MODEL_REPO_ID, filename="small_violence_classifier.pth")
model.load_state_dict(torch.load(model_weights_path, map_location=device))
model.to(device)
model.eval()
print(f"Model loaded successfully with {NUM_FRAMES} frames and image size {IMAGE_SIZE}.")
# --- 3. Define Preprocessing Transform ---
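# The mean/std below are the standard ImageNet statistics, matching what the
# pretrained MobileNetV3 backbone expects.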
transform = transforms.Compose([
transforms.Resize(IMAGE_SIZE),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# --- 4. Gradio Live Inference Function (Generator) ---
# This function will receive individual frames from the webcam
def predict_live_frames(input_frame):
global frame_buffer, current_prediction_label, current_probabilities # Use global to maintain state across calls
if input_frame is None:
# If no frame is received (e.g., webcam not active), yield a black frame or handle gracefully
dummy_frame = np.zeros((200, 400, 3), dtype=np.uint8)
cv2.putText(dummy_frame, "Waiting for webcam input...", (10, 100), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
yield dummy_frame
return # Exit if no frame to process
# Gradio Webcam gives NumPy array (H, W, C) in RGB
pil_image = Image.fromarray(input_frame)
# Apply transformations (outputs C, H, W tensor)
processed_frame_tensor = transform(pil_image)
frame_buffer.append(processed_frame_tensor)
# Perform prediction only when the buffer is full
if len(frame_buffer) == NUM_FRAMES:
# Stack the buffered frames and add a batch dimension
input_tensor = torch.stack(frame_buffer, dim=0).unsqueeze(0).to(device)
with torch.no_grad():
outputs = model(input_tensor)
probabilities = torch.softmax(outputs, dim=1)
predicted_class_idx = torch.argmax(probabilities, dim=1).item()
current_prediction_label = f"Class: {CLASS_LABELS[predicted_class_idx]}"
current_probabilities = {CLASS_LABELS[i]: prob.item() for i, prob in enumerate(probabilities[0])}
        # --- Sliding Window ---
        # Drop the oldest frame(s) so the buffer overlaps between predictions.
        # slide_window_by = 1 predicts on every new frame (smoothest updates,
        # highest compute); NUM_FRAMES // 2 would predict every 4 frames when
        # NUM_FRAMES is 8; NUM_FRAMES gives non-overlapping windows.
        slide_window_by = 1
frame_buffer = frame_buffer[slide_window_by:]
# --- Draw Prediction on the current input frame ---
# Convert the input_frame (RGB NumPy array) to BGR for OpenCV drawing
display_frame = cv2.cvtColor(input_frame, cv2.COLOR_RGB2BGR)
# Draw the main prediction label
text_color = (0, 255, 0) # Green (BGR)
text_outline_color = (0, 0, 0) # Black
font_scale = 1.0
font_thickness = 2
# Draw outline first for better readability
cv2.putText(display_frame, current_prediction_label, (10, 40),
cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_outline_color, font_thickness + 2, cv2.LINE_AA)
# Draw actual text
cv2.putText(display_frame, current_prediction_label, (10, 40),
cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_color, font_thickness, cv2.LINE_AA)
# Draw probabilities for all classes (like YOLO)
y_offset = 80 # Start drawing probabilities slightly lower
for label, prob in current_probabilities.items():
prob_text = f"{label}: {prob:.2f}"
cv2.putText(display_frame, prob_text, (10, y_offset),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, text_outline_color, 2, cv2.LINE_AA)
        cv2.putText(display_frame, prob_text, (10, y_offset),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 1, cv2.LINE_AA)  # Yellow (BGR) for probabilities
y_offset += 30 # Move down for next probability
# Yield the processed frame back to Gradio for display
# Gradio expects RGB NumPy array for video/image components
yield cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB)
# --- Initialize global state for the generator function ---
frame_buffer = [] # Buffer for collecting frames for model input
current_prediction_label = "Initializing..."
current_probabilities = {label: 0.0 for label in CLASS_LABELS} # Initial probabilities
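# Caveat: these module-level globals are shared by every session in the same
# Python process, so concurrent users of the Space would interleave frames in
# one buffer.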
# --- 5. Gradio Interface Setup ---
iface = gr.Interface(
fn=predict_live_frames,
# Use gr.Webcam for direct webcam input
inputs=gr.Webcam(streaming=True, label="Live Webcam Feed for Violence Detection"),
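    # Note: gr.Webcam is the Gradio 3.x streaming shortcut; on Gradio 4+ the
    # equivalent (an assumption about the installed version) would be
    # gr.Image(sources=["webcam"], streaming=True).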
# Outputs are updated continuously by the generator
outputs=gr.Image(type="numpy", label="Live Prediction Output"), # Using Image as output for continuous frames
title="Real-time Violence Detection with SmallVideoClassifier (Webcam)",
description=(
"This model detects violence in a live webcam feed. "
"Predictions (Class and Probabilities) will be displayed on each frame. "
"Please allow webcam access when prompted."
),
allow_flagging="never", # Disable flagging on Hugging Face Spaces
)
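# launch() with no arguments is sufficient on Hugging Face Spaces; for local
# testing, iface.launch(share=True) would additionally expose a temporary
# public URL.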
if __name__ == "__main__":
    iface.launch()