owinymarvin's picture
latest changes
429fbf1
raw
history blame
7.9 kB
import gradio as gr
import torch
import cv2
import numpy as np
import os
import json
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download
import time
# --- 1. Define Model Architecture ---
class SmallVideoClassifier(torch.nn.Module):
    """Clip classifier: per-frame MobileNetV3-Small features, mean-pooled over time, then an MLP.

    Each frame is encoded independently by a MobileNetV3-Small backbone whose
    classification head is replaced with an identity. The per-frame feature
    vectors are averaged over the time axis, and a two-layer MLP maps the
    pooled vector to class logits.

    The attribute names (``feature_extractor``, ``temporal_aggregator``,
    ``classifier``) and the layer order inside ``classifier`` determine the
    ``state_dict`` keys, so they must not change or saved checkpoints stop loading.
    """

    def __init__(self, num_classes=2, num_frames=8):
        """Build the network.

        Args:
            num_classes: Number of output logits.
            num_frames: Clip length the model is configured for. ``forward`` reads
                the actual frame count from the input shape; this value is only
                stored for reference.
        """
        super().__init__()
        from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
        try:
            weights = MobileNet_V3_Small_Weights.IMAGENET1K_V1
        except AttributeError:
            print("Warning: MobileNet_V3_Small_Weights.IMAGENET1K_V1 not found, initializing without pre-trained weights.")
            weights = None
        # NOTE(review): if a full checkpoint is always loaded right after construction
        # (as this app does), these pretrained backbone weights get overwritten.
        self.feature_extractor = mobilenet_v3_small(weights=weights)
        # Drop the backbone's own head so it returns per-frame feature vectors.
        self.feature_extractor.classifier = torch.nn.Identity()
        # Per-frame feature width the backbone must produce; forward reshapes to it.
        self.num_spatial_features = 576
        self.num_frames = num_frames
        # Pools the last axis to length 1, i.e. a mean over frames once time is moved last.
        self.temporal_aggregator = torch.nn.AdaptiveAvgPool1d(1)
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(self.num_spatial_features, 512),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(512, num_classes),
        )

    def forward(self, pixel_values):
        """Return class logits for a batch of clips.

        Args:
            pixel_values: Tensor of shape (batch, frames, channels, height, width).
                It does not need to be contiguous.

        Returns:
            Tensor of shape (batch, num_classes) with unnormalized logits.
        """
        batch_size, num_frames, channels, height, width = pixel_values.shape
        # reshape (not view) so non-contiguous inputs, e.g. permuted tensors, also work.
        frames = pixel_values.reshape(batch_size * num_frames, channels, height, width)
        spatial_features = self.feature_extractor(frames)
        spatial_features = spatial_features.reshape(batch_size, num_frames, self.num_spatial_features)
        # (B, T, F) -> (B, F, T) so the pool averages over time -> (B, F, 1) -> (B, F).
        temporal_features = self.temporal_aggregator(spatial_features.permute(0, 2, 1)).squeeze(-1)
        return self.classifier(temporal_features)
# --- 2. Configuration and Model Loading ---
# Runs at import time: downloads the model config and weights from the Hugging
# Face Hub and builds the CPU model used by the inference handler.
HF_USERNAME = "owinymarvin"
NEW_MODEL_REPO_ID_SHORT = "timesformer-violence-detector"
NEW_MODEL_REPO_ID = f"{HF_USERNAME}/{NEW_MODEL_REPO_ID_SHORT}"

print(f"Downloading config.json from {NEW_MODEL_REPO_ID}...")
config_path = hf_hub_download(repo_id=NEW_MODEL_REPO_ID, filename="config.json")
with open(config_path, 'r', encoding="utf-8") as f:
    model_config = json.load(f)

NUM_FRAMES = model_config.get('num_frames', 8)
# Accept either a [height, width] sequence or a single int meaning a square size;
# tuple() on a bare int would raise TypeError.
_image_size = model_config.get('image_size', [224, 224])
IMAGE_SIZE = (_image_size, _image_size) if isinstance(_image_size, int) else tuple(_image_size)
NUM_CLASSES = model_config.get('num_classes', 2)

CLASS_LABELS = ["Non-violence", "Violence"]
if NUM_CLASSES != len(CLASS_LABELS):
    print(f"Warning: NUM_CLASSES in config ({NUM_CLASSES}) does not match hardcoded CLASS_LABELS length ({len(CLASS_LABELS)}). Adjust CLASS_LABELS if needed.")
    # Make the list exactly NUM_CLASSES long: generic names for extra classes,
    # truncation for fewer. Otherwise indexing CLASS_LABELS with a model output
    # index raises IndexError during inference.
    CLASS_LABELS = (CLASS_LABELS + [f"Class {i}" for i in range(len(CLASS_LABELS), NUM_CLASSES)])[:NUM_CLASSES]

device = torch.device("cpu")
print(f"Using device: {device}")

model = SmallVideoClassifier(num_classes=NUM_CLASSES, num_frames=NUM_FRAMES)
print(f"Downloading model weights from {NEW_MODEL_REPO_ID}...")
model_weights_path = hf_hub_download(repo_id=NEW_MODEL_REPO_ID, filename="small_violence_classifier.pth")
# SECURITY NOTE(review): torch.load unpickles the file, and a malicious checkpoint
# can execute arbitrary code. This file comes from a remote Hub repo; on torch
# versions that support it, consider passing weights_only=True.
model.load_state_dict(torch.load(model_weights_path, map_location=device))
model.to(device)
model.eval()
print(f"Model loaded successfully with {NUM_FRAMES} frames and image size {IMAGE_SIZE}.")
# --- 3. Define Preprocessing Transform ---
# Per-frame preprocessing applied to every webcam frame before it is buffered.
# Steps: resize to the configured IMAGE_SIZE, convert the PIL image to a tensor,
# then normalize each channel. The mean/std values are the usual ImageNet channel
# statistics, presumably matching training-time preprocessing (TODO confirm).
transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# --- Global state for the generator function ---
# Module-level state read and rebound by predict_live_frames via `global`.
# NOTE(review): as module globals, these appear to be shared by all concurrent
# users of the app; confirm whether per-session state is needed.
frame_buffer = []  # preprocessed frame tensors forming the sliding window
current_prediction_label = "Initializing..."  # overlay text shown until the first prediction
current_probabilities = dict.fromkeys(CLASS_LABELS, 0.0)  # per-class probabilities shown in the overlay
# --- 4. Gradio Live Inference Function (Generator) ---
def _waiting_frame():
    """Return a 200x400 black RGB placeholder saying no webcam frame has arrived yet."""
    dummy_frame = np.zeros((200, 400, 3), dtype=np.uint8)
    cv2.putText(dummy_frame, "Waiting for webcam input...", (10, 100), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
    return dummy_frame


def _classify_window(window):
    """Classify one window of preprocessed frame tensors.

    Args:
        window: List of per-frame tensors produced by ``transform``.

    Returns:
        Tuple ``(label_text, class_probabilities)``:
        - ``label_text`` is the overlay string for the most likely class.
        - ``class_probabilities`` maps each CLASS_LABELS entry to its softmax probability.
    """
    # Stack frames along a new time axis, then add the batch axis the model expects.
    input_tensor = torch.stack(window, dim=0).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(input_tensor)
    probabilities = torch.softmax(outputs, dim=1)[0]
    predicted_class_idx = int(torch.argmax(probabilities).item())
    label_text = f"Class: {CLASS_LABELS[predicted_class_idx]}"
    class_probabilities = {CLASS_LABELS[i]: p for i, p in enumerate(probabilities.tolist())}
    return label_text, class_probabilities


def _draw_overlay(rgb_frame, label_text, class_probabilities):
    """Return an RGB copy of ``rgb_frame`` with the prediction and per-class probabilities drawn on it."""
    # Draw in BGR so the colour tuples below are in OpenCV channel order, then convert back to RGB.
    display_frame = cv2.cvtColor(rgb_frame, cv2.COLOR_RGB2BGR)
    text_color = (0, 255, 0)  # Green (BGR)
    text_outline_color = (0, 0, 0)  # Black
    font_scale = 1.0
    font_thickness = 2
    # Main label: a thicker black pass underneath acts as an outline for the green text.
    cv2.putText(display_frame, label_text, (10, 40),
                cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_outline_color, font_thickness + 2, cv2.LINE_AA)
    cv2.putText(display_frame, label_text, (10, 40),
                cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_color, font_thickness, cv2.LINE_AA)
    # One line per class, 30 px apart, starting below the main label.
    y_offset = 80
    for label, prob in class_probabilities.items():
        prob_text = f"{label}: {prob:.2f}"
        cv2.putText(display_frame, prob_text, (10, y_offset),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, text_outline_color, 2, cv2.LINE_AA)
        cv2.putText(display_frame, prob_text, (10, y_offset),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 0), 1, cv2.LINE_AA)
        y_offset += 30
    return cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB)


def predict_live_frames(input_frame):
    """Gradio streaming handler: buffer one webcam frame, classify, and yield an annotated frame.

    Each call appends the preprocessed frame to the module-level ``frame_buffer``.
    Once at least NUM_FRAMES frames are buffered:
    - the most recent NUM_FRAMES are classified;
    - the module-level prediction state is updated;
    - the oldest ``slide_window_by`` frames are dropped, giving a sliding window.
    Before the first prediction, the initial state is drawn.

    Args:
        input_frame: Webcam frame as a numpy image array, or None when no frame is
            available. Assumed uint8 RGB as delivered by a numpy-typed Gradio image
            input (TODO confirm).

    Yields:
        Exactly one RGB uint8 numpy image: the annotated frame, or a placeholder
        when ``input_frame`` is None.

    NOTE(review): the buffer and prediction state are module globals, so they
    appear to be shared by every concurrent session; confirm whether per-session
    ``gr.State`` is needed.
    """
    global frame_buffer, current_prediction_label, current_probabilities
    if input_frame is None:
        # No frame received (e.g. webcam not active or disconnected).
        yield _waiting_frame()
        return

    # Normalize to 3-channel RGB. Grayscale or RGBA frames would otherwise give a
    # tensor with the wrong channel count for Normalize and break cvtColor.
    pil_image = Image.fromarray(input_frame).convert("RGB")
    rgb_frame = np.array(pil_image)
    frame_buffer.append(transform(pil_image))

    slide_window_by = 1
    if len(frame_buffer) >= NUM_FRAMES:
        current_prediction_label, current_probabilities = _classify_window(frame_buffer[-NUM_FRAMES:])
        frame_buffer = frame_buffer[slide_window_by:]

    yield _draw_overlay(rgb_frame, current_prediction_label, current_probabilities)
# --- 5. Gradio Blocks Interface Setup ---
with gr.Blocks(
    title="Real-time Violence Detection",  # Title for the browser tab
    # Optional: a subtle theme change using a Material-style cyan palette.
    # gr.themes.Color takes the shades c50..c950. The Material accent shades
    # (ca50/ca100/ca200/ca400) are not Color parameters, so they were dropped
    # and a darker c950 shade was added.
    theme=gr.themes.Default(
        primary_hue=gr.themes.Color(
            c50="#e0f7fa",
            c100="#b2ebf2",
            c200="#80deea",
            c300="#4dd0e1",
            c400="#26c6da",
            c500="#00bcd4",
            c600="#00acc1",
            c700="#0097a7",
            c800="#00838f",
            c900="#006064",
            c950="#004d52",
        )
    ),
) as demo:
    # Title and description shown above the feeds.
    gr.Markdown(
        """
# 🎬 Real-time Violence Detection
**Live Feed with Constant Predictions**
This model analyzes your live webcam feed for violence, displaying the predicted class and probabilities on the screen.
Please grant webcam access when prompted by your browser.
"""
    )
    with gr.Row():
        # Input: live webcam frames streamed as numpy RGB arrays. predict_live_frames
        # calls Image.fromarray / cv2.cvtColor on its input, so it needs per-frame
        # numpy images, which a streaming webcam gr.Image provides (gr.Video inputs
        # are file paths and are not a streaming webcam source).
        video_input = gr.Image(
            sources=["webcam"],
            streaming=True,
            type="numpy",
            label="Live Webcam Feed",
            height=480,
            width=640,
        )
        # Output: annotated frames yielded by predict_live_frames.
        video_output = gr.Image(
            type="numpy",
            label="Processed Feed with Predictions",
            height=480,
            width=640,
        )
    # Each streamed webcam frame triggers one call of the handler.
    video_input.stream(
        predict_live_frames,
        inputs=video_input,
        outputs=video_output,
    )
demo.launch()