import gradio as gr
import torch
import cv2
import numpy as np
import os
import json
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download
import tempfile # For temporary file handling
# --- 1. Define Model Architecture (Copy from small_video_classifier.py) ---
# This is crucial because we need the model class definition to load weights.
class SmallVideoClassifier(torch.nn.Module):
    def __init__(self, num_classes=2, num_frames=8):
        super(SmallVideoClassifier, self).__init__()
        from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
        # Load ImageNet pre-trained weights when available; fall back to random
        # initialization in environments where they cannot be resolved.
        try:
            weights = MobileNet_V3_Small_Weights.IMAGENET1K_V1
        except Exception:
            print("Warning: MobileNet_V3_Small_Weights.IMAGENET1K_V1 not found, initializing without pre-trained weights.")
            weights = None  # or provide a project-specific default
        self.feature_extractor = mobilenet_v3_small(weights=weights)
        self.feature_extractor.classifier = torch.nn.Identity()
        self.num_spatial_features = 576  # Feature dimension of MobileNetV3-Small once its classifier head is removed
        self.temporal_aggregator = torch.nn.AdaptiveAvgPool1d(1)
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(self.num_spatial_features, 512),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(512, num_classes)
        )

    def forward(self, pixel_values):
        # pixel_values: (batch_size, num_frames, channels, height, width)
        batch_size, num_frames, channels, height, width = pixel_values.shape
        # Fold the time dimension into the batch so each frame is encoded independently
        x = pixel_values.view(batch_size * num_frames, channels, height, width)
        spatial_features = self.feature_extractor(x)
        spatial_features = spatial_features.view(batch_size, num_frames, self.num_spatial_features)
        # Average over the time dimension to get one feature vector per clip
        temporal_features = self.temporal_aggregator(spatial_features.permute(0, 2, 1)).squeeze(-1)
        logits = self.classifier(temporal_features)
        return logits
# --- 2. Configuration and Model Loading ---
HF_USERNAME = "owinymarvin"
NEW_MODEL_REPO_ID_SHORT = "timesformer-violence-detector"
NEW_MODEL_REPO_ID = f"{HF_USERNAME}/{NEW_MODEL_REPO_ID_SHORT}"
# Download config.json to get model parameters
print(f"Downloading config.json from {NEW_MODEL_REPO_ID}...")
config_path = hf_hub_download(repo_id=NEW_MODEL_REPO_ID, filename="config.json")
with open(config_path, 'r') as f:
    model_config = json.load(f)
NUM_FRAMES = model_config.get('num_frames', 8)
IMAGE_SIZE = tuple(model_config.get('image_size', [224, 224]))
NUM_CLASSES = model_config.get('num_classes', 2)
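# For reference, the fields read above correspond to a config.json of roughly the
# following shape. The values shown are the fallback defaults used in this file,
# not necessarily the actual contents of the repo:
# {
#     "num_frames": 8,
#     "image_size": [224, 224],
#     "num_classes": 2
# }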
# Define class labels (adjust if your dataset had different labels/order)
CLASS_LABELS = ["Non-violence", "Violence"]
if NUM_CLASSES != len(CLASS_LABELS):
print(f"Warning: NUM_CLASSES in config ({NUM_CLASSES}) does not match hardcoded CLASS_LABELS length ({len(CLASS_LABELS)}). Adjust CLASS_LABELS if needed.")
# Initialize the model
device = torch.device("cpu")  # Run explicitly on CPU (suitable for Spaces CPU hardware)
print(f"Using device: {device}")
model = SmallVideoClassifier(num_classes=NUM_CLASSES, num_frames=NUM_FRAMES)
# Download model weights
print(f"Downloading model weights from {NEW_MODEL_REPO_ID}...")
model_weights_path = hf_hub_download(repo_id=NEW_MODEL_REPO_ID, filename="small_violence_classifier.pth")
model.load_state_dict(torch.load(model_weights_path, map_location=device))
model.to(device)
model.eval() # Set model to evaluation mode
print(f"Model loaded successfully with {NUM_FRAMES} frames and image size {IMAGE_SIZE}.")
# --- 3. Define Preprocessing Transform ---
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
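# Minimal example of what the transform produces (assumption: any RGB PIL image
# can stand in for a video frame here): one frame becomes a normalized float
# tensor of shape (3, IMAGE_SIZE[0], IMAGE_SIZE[1]).
_example_frame = Image.new("RGB", (640, 480))
print(f"Transform output shape for one frame: {tuple(transform(_example_frame).shape)}")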
# --- 4. Gradio Inference Function ---
def predict_video(video_path):
    if video_path is None:
        return None  # Or raise an error, or return a placeholder video

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video file {video_path}.")
        raise ValueError(f"Could not open video file {video_path}. Please ensure it's a valid video format.")

    # Get video properties
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps is None or fps <= 0:
        fps = 25.0  # OpenCV reports 0.0 for some containers; fall back to a sensible default
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Create a temporary output video file.
    # Use tempfile to ensure proper cleanup on Hugging Face Spaces.
    temp_output_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    output_video_path = temp_output_file.name
    temp_output_file.close()  # Close the file handle so cv2.VideoWriter can open it

    # Define the codec and create the VideoWriter object.
    # For MP4, 'mp4v' is generally compatible.
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))

    print(f"Processing video: {video_path}")
    print(f"Total frames: {total_frames}, FPS: {fps}")
    print(f"Output video will be saved to: {output_video_path}")

    frame_buffer = []  # Stores NUM_FRAMES frames for each prediction batch
    current_prediction_label = "Processing..."  # Initial label
    frame_idx = 0

    while True:
        ret, frame = cap.read()
        if not ret:
            break  # End of video
        frame_idx += 1

        # Convert frame from BGR (OpenCV) to RGB (PIL/PyTorch)
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(frame_rgb)

        # Apply transformations and add to buffer
        processed_frame = transform(pil_image)  # shape: (C, H, W)
        frame_buffer.append(processed_frame)

        # Perform prediction when the buffer is full
        if len(frame_buffer) == NUM_FRAMES:
            # Stack the buffered frames and add a batch dimension.
            # Resulting shape: (1, NUM_FRAMES, C, H, W)
            input_tensor = torch.stack(frame_buffer, dim=0).unsqueeze(0).to(device)
            with torch.no_grad():
                outputs = model(input_tensor)
                probabilities = torch.softmax(outputs, dim=1)
                predicted_class_idx = torch.argmax(probabilities, dim=1).item()
            current_prediction_label = f"Prediction: {CLASS_LABELS[predicted_class_idx]} (Prob: {probabilities[0, predicted_class_idx]:.2f})"
            # Reset the buffer for the next non-overlapping window.
            frame_buffer = []
            # For a sliding window instead (more continuous output but higher compute),
            # keep the most recent frames rather than clearing; see the sketch after
            # this function.

        # Draw the prediction text on the current frame.
        # The prediction lags by up to NUM_FRAMES frames, since it is based on the
        # most recently completed batch; we display the last known prediction.
        cv2.putText(frame, current_prediction_label, (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2, cv2.LINE_AA)

        # Write the processed frame to the output video
        out.write(frame)

    # Release resources
    cap.release()
    out.release()

    print(f"Video processing complete. Output saved to: {output_video_path}")
    return output_video_path  # Gradio expects the path to the output video
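# A minimal sketch of the sliding-window buffer update mentioned inside
# predict_video above (an illustration, not what this Space currently runs):
# keep the newest frames instead of clearing the buffer, so a new prediction
# can be made on every frame once the window is full. The helper name
# `update_sliding_buffer` is illustrative only and is not called anywhere.
def update_sliding_buffer(frame_buffer, processed_frame, num_frames=NUM_FRAMES):
    frame_buffer.append(processed_frame)
    if len(frame_buffer) > num_frames:
        frame_buffer.pop(0)  # drop the oldest frame, i.e. slide the window by one
    return frame_buffer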
# --- 5. Gradio Interface Setup ---
iface = gr.Interface(
    fn=predict_video,
    # gr.Video does not take a 'type="filepath"' argument; it already passes the uploaded file's path to fn.
    inputs=gr.Video(label="Upload Video for Violence Detection (MP4 recommended)"),
    outputs=gr.Video(label="Processed Video with Predictions"),
    title="Real-time Violence Detection with SmallVideoClassifier",
    description="Upload a video, and the model will analyze it for violence, displaying the predicted class and confidence on each frame.",
    allow_flagging="never",  # Disable flagging on Hugging Face Spaces
    examples=[
        # You can provide example video URLs or paths if you have them publicly available,
        # e.g., "https://huggingface.co/datasets/gradio/test-files/resolve/main/video.mp4"
    ]
)
iface.launch()
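# Deployment note (an assumption about usage, not something this file currently
# does): for long videos the request can exceed Gradio's default timeout, so
# enabling the built-in request queue before launching may help, e.g.:
# iface.queue().launch()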