import gradio as gr
import torch
import cv2
import numpy as np
import os
import json
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download
# --- 1. Define Model Architecture ---
class SmallVideoClassifier(torch.nn.Module):
    def __init__(self, num_classes=2, num_frames=8):
        # num_frames is kept for interface symmetry; AdaptiveAvgPool1d below
        # handles clips of any length.
        super(SmallVideoClassifier, self).__init__()
        from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
        try:
            weights = MobileNet_V3_Small_Weights.IMAGENET1K_V1
        except Exception:
            print("Warning: MobileNet_V3_Small_Weights.IMAGENET1K_V1 not found, initializing without pre-trained weights.")
            weights = None
        # Lightweight per-frame backbone; the classifier head is replaced with
        # Identity so the model emits raw spatial features.
        self.feature_extractor = mobilenet_v3_small(weights=weights)
        self.feature_extractor.classifier = torch.nn.Identity()
        self.num_spatial_features = 576  # MobileNetV3-Small feature dimension
        # Average-pool features across the temporal (frame) axis.
        self.temporal_aggregator = torch.nn.AdaptiveAvgPool1d(1)
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(self.num_spatial_features, 512),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(512, num_classes)
        )

    def forward(self, pixel_values):
        # pixel_values: (batch, num_frames, channels, height, width)
        batch_size, num_frames, channels, height, width = pixel_values.shape
        # Fold frames into the batch dimension for per-frame feature extraction.
        x = pixel_values.view(batch_size * num_frames, channels, height, width)
        spatial_features = self.feature_extractor(x)
        spatial_features = spatial_features.view(batch_size, num_frames, self.num_spatial_features)
        # Pool over frames: (batch, features, frames) -> (batch, features)
        temporal_features = self.temporal_aggregator(spatial_features.permute(0, 2, 1)).squeeze(-1)
        logits = self.classifier(temporal_features)
        return logits
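
# Quick shape sanity check (an illustrative sketch, not part of the original
# app; DEBUG_SHAPES is a hypothetical env flag so this never runs on the Space).
# A clip tensor of shape (batch, num_frames, 3, H, W) should yield
# (batch, num_classes) logits.
if os.environ.get("DEBUG_SHAPES"):
    _check_model = SmallVideoClassifier(num_classes=2, num_frames=8)
    _dummy_clip = torch.randn(1, 8, 3, 224, 224)
    print(_check_model(_dummy_clip).shape)  # expected: torch.Size([1, 2])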
# --- 2. Configuration and Model Loading ---
HF_USERNAME = "owinymarvin"
NEW_MODEL_REPO_ID_SHORT = "timesformer-violence-detector"
NEW_MODEL_REPO_ID = f"{HF_USERNAME}/{NEW_MODEL_REPO_ID_SHORT}"
print(f"Downloading config.json from {NEW_MODEL_REPO_ID}...")
config_path = hf_hub_download(repo_id=NEW_MODEL_REPO_ID, filename="config.json")
with open(config_path, 'r') as f:
    model_config = json.load(f)
NUM_FRAMES = model_config.get('num_frames', 8)
IMAGE_SIZE = tuple(model_config.get('image_size', [224, 224]))
NUM_CLASSES = model_config.get('num_classes', 2)
CLASS_LABELS = ["Non-violence", "Violence"]
if NUM_CLASSES != len(CLASS_LABELS):
    print(f"Warning: NUM_CLASSES in config ({NUM_CLASSES}) does not match hardcoded CLASS_LABELS length ({len(CLASS_LABELS)}). Adjust CLASS_LABELS if needed.")
device = torch.device("cpu")
print(f"Using device: {device}")
model = SmallVideoClassifier(num_classes=NUM_CLASSES, num_frames=NUM_FRAMES)
print(f"Downloading model weights from {NEW_MODEL_REPO_ID}...")
model_weights_path = hf_hub_download(repo_id=NEW_MODEL_REPO_ID, filename="small_violence_classifier.pth")
model.load_state_dict(torch.load(model_weights_path, map_location=device))
model.to(device)
model.eval()
print(f"Model loaded successfully with {NUM_FRAMES} frames and image size {IMAGE_SIZE}.")
# --- 3. Define Preprocessing Transform ---
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    # ImageNet normalization statistics, matching the pre-trained backbone.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
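
# Illustration of what the transform produces (guarded sketch; DEBUG_TRANSFORM
# is a hypothetical flag): any RGB PIL image becomes a normalized float tensor
# of shape (3,) + IMAGE_SIZE, i.e. (3, 224, 224) with the defaults above.
if os.environ.get("DEBUG_TRANSFORM"):
    _blank = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))
    print(transform(_blank).shape)  # torch.Size([3, 224, 224])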
# --- Global state for the generator function ---
frame_buffer = []
current_prediction_label = "Initializing..."
current_probabilities = {label: 0.0 for label in CLASS_LABELS}
# --- 4. Gradio Live Inference Function (Generator) ---
def predict_live_frames(input_frame):
    global frame_buffer, current_prediction_label, current_probabilities

    if input_frame is None:
        # No webcam frame yet: show a placeholder so the output pane is not blank.
        dummy_frame = np.zeros((200, 400, 3), dtype=np.uint8)
        cv2.putText(dummy_frame, "Waiting for webcam input...", (10, 100),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
        yield dummy_frame
        return

    pil_image = Image.fromarray(input_frame)
    processed_frame_tensor = transform(pil_image)
    frame_buffer.append(processed_frame_tensor)

    # Slide the window forward by one frame after each prediction, so the model
    # re-evaluates on every new frame once the buffer is full.
    slide_window_by = 1
    if len(frame_buffer) >= NUM_FRAMES:
        input_tensor = torch.stack(frame_buffer[-NUM_FRAMES:], dim=0).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(input_tensor)
            probabilities = torch.softmax(outputs, dim=1)
        predicted_class_idx = torch.argmax(probabilities, dim=1).item()
        current_prediction_label = f"Class: {CLASS_LABELS[predicted_class_idx]}"
        current_probabilities = {CLASS_LABELS[i]: prob.item() for i, prob in enumerate(probabilities[0])}
        frame_buffer = frame_buffer[slide_window_by:]

    # Gradio delivers RGB frames; OpenCV draws in BGR, so convert before annotating.
    display_frame = cv2.cvtColor(input_frame, cv2.COLOR_RGB2BGR)

    # Draw the main prediction label with a black outline for readability.
    text_color = (0, 255, 0)        # green (BGR)
    text_outline_color = (0, 0, 0)  # black
    font_scale = 1.0
    font_thickness = 2
    cv2.putText(display_frame, current_prediction_label, (10, 40),
                cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_outline_color, font_thickness + 2, cv2.LINE_AA)
    cv2.putText(display_frame, current_prediction_label, (10, 40),
                cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_color, font_thickness, cv2.LINE_AA)

    # Draw per-class probabilities below the label.
    y_offset = 80
    for label, prob in current_probabilities.items():
        prob_text = f"{label}: {prob:.2f}"
        cv2.putText(display_frame, prob_text, (10, y_offset),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, text_outline_color, 2, cv2.LINE_AA)
        cv2.putText(display_frame, prob_text, (10, y_offset),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 0), 1, cv2.LINE_AA)
        y_offset += 30

    # Convert back to RGB for Gradio's image output.
    yield cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB)
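
# Offline testing sketch (assumption: not used by the Space itself). The same
# generator can be fed frames from a video file via OpenCV, which is convenient
# for testing without a webcam. "sample.mp4" is a placeholder path.
def annotate_video_file(path="sample.mp4", max_frames=200):
    cap = cv2.VideoCapture(path)
    annotated_frames = []
    while len(annotated_frames) < max_frames:
        ok, bgr = cap.read()
        if not ok:
            break
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        # Each call yields exactly one annotated RGB frame.
        for annotated in predict_live_frames(rgb):
            annotated_frames.append(annotated)
    cap.release()
    return annotated_frames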
# --- 5. Gradio Blocks Interface Setup ---
# Default theme is used; a custom gr.Color-based theme previously raised an AttributeError.
with gr.Blocks(title="Real-time Violence Detection") as demo:  # title sets the browser tab name
    gr.Markdown(
        """
        # 🎬 Real-time Violence Detection
        **Live Feed with Constant Predictions**

        This model analyzes your live webcam feed for violence, displaying the predicted class and probabilities on screen.
        Please grant webcam access when prompted by your browser.
        """
    )
    with gr.Row():
        video_input = gr.Video(
            sources=["webcam"],
            streaming=True,
            label="Live Webcam Feed",
            height=480,
            width=640
        )
        video_output = gr.Image(
            type="numpy",
            label="Processed Feed with Predictions",
            height=480,
            width=640
        )
    video_input.stream(
        predict_live_frames,
        inputs=video_input,
        outputs=video_output
    )

demo.launch()