import gradio as gr
import torch
import cv2
import numpy as np
import os
import json
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download
import tempfile  # For temporary file handling


# --- 1. Define Model Architecture (Copy from small_video_classifier.py) ---
# This is crucial because we need the model class definition to load weights.
class SmallVideoClassifier(torch.nn.Module):
    def __init__(self, num_classes=2, num_frames=8):
        super(SmallVideoClassifier, self).__init__()
        from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
        try:
            weights = MobileNet_V3_Small_Weights.IMAGENET1K_V1
        except Exception:
            print("Warning: MobileNet_V3_Small_Weights.IMAGENET1K_V1 not found, initializing without pre-trained weights.")
            weights = None
        self.feature_extractor = mobilenet_v3_small(weights=weights)
        self.feature_extractor.classifier = torch.nn.Identity()
        self.num_spatial_features = 576
        self.temporal_aggregator = torch.nn.AdaptiveAvgPool1d(1)
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(self.num_spatial_features, 512),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(512, num_classes)
        )

    def forward(self, pixel_values):
        batch_size, num_frames, channels, height, width = pixel_values.shape
        x = pixel_values.view(batch_size * num_frames, channels, height, width)
        spatial_features = self.feature_extractor(x)
        spatial_features = spatial_features.view(batch_size, num_frames, self.num_spatial_features)
        temporal_features = self.temporal_aggregator(spatial_features.permute(0, 2, 1)).squeeze(-1)
        logits = self.classifier(temporal_features)
        return logits


# --- 2. Configuration and Model Loading ---
HF_USERNAME = "owinymarvin"
NEW_MODEL_REPO_ID_SHORT = "timesformer-violence-detector"
NEW_MODEL_REPO_ID = f"{HF_USERNAME}/{NEW_MODEL_REPO_ID_SHORT}"

print(f"Downloading config.json from {NEW_MODEL_REPO_ID}...")
config_path = hf_hub_download(repo_id=NEW_MODEL_REPO_ID, filename="config.json")
with open(config_path, 'r') as f:
    model_config = json.load(f)

NUM_FRAMES = model_config.get('num_frames', 8)
IMAGE_SIZE = tuple(model_config.get('image_size', [224, 224]))
NUM_CLASSES = model_config.get('num_classes', 2)

CLASS_LABELS = ["Non-violence", "Violence"]
if NUM_CLASSES != len(CLASS_LABELS):
    print(f"Warning: NUM_CLASSES in config ({NUM_CLASSES}) does not match hardcoded CLASS_LABELS length ({len(CLASS_LABELS)}). Adjust CLASS_LABELS if needed.")

device = torch.device("cpu")
print(f"Using device: {device}")

model = SmallVideoClassifier(num_classes=NUM_CLASSES, num_frames=NUM_FRAMES)

print(f"Downloading model weights from {NEW_MODEL_REPO_ID}...")
model_weights_path = hf_hub_download(repo_id=NEW_MODEL_REPO_ID, filename="small_violence_classifier.pth")
model.load_state_dict(torch.load(model_weights_path, map_location=device))
model.to(device)
model.eval()

print(f"Model loaded successfully with {NUM_FRAMES} frames and image size {IMAGE_SIZE}.")

# --- 3. Define Preprocessing Transform ---
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# --- 4. Gradio Inference Function ---
def predict_video(video_path):
    if video_path is None:
        return None

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video file {video_path}.")
        raise ValueError(f"Could not open video file {video_path}. Please ensure it's a valid video format.")

    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    # Ensure FPS is not zero to avoid division by zero errors, default to 25 if needed
    if fps <= 0:
        fps = 25.0
        print(f"Warning: Original video FPS was 0 or less, defaulting to {fps}.")

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    temp_output_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    output_video_path = temp_output_file.name
    temp_output_file.close()

    # --- CHANGED: Use XVID codec for better browser compatibility ---
    # This might prevent Gradio's internal re-encoding.
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))

    print(f"Processing video: {video_path}")
    print(f"Total frames: {total_frames}, FPS: {fps}")
    print(f"Output video will be saved to: {output_video_path}")

    frame_buffer = []
    current_prediction_label = "Processing..."
    frame_idx = 0

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        frame_idx += 1

        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(frame_rgb)
        processed_frame = transform(pil_image)
        frame_buffer.append(processed_frame)

        if len(frame_buffer) == NUM_FRAMES:
            input_tensor = torch.stack(frame_buffer, dim=0).unsqueeze(0).to(device)

            with torch.no_grad():
                outputs = model(input_tensor)
                probabilities = torch.softmax(outputs, dim=1)
                predicted_class_idx = torch.argmax(probabilities, dim=1).item()
                current_prediction_label = f"Prediction: {CLASS_LABELS[predicted_class_idx]} (Prob: {probabilities[0, predicted_class_idx]:.2f})"

            frame_buffer = []
            # If you want a sliding window, you would do something like:
            # frame_buffer = frame_buffer[int(NUM_FRAMES * 0.5):]  # Slide by half the window size

        # Draw prediction text on the current frame
        # Ensure text color is clearly visible (e.g., white or bright green)
        # Add a black outline for better readability
        text_color = (0, 255, 0)        # Green (BGR format for OpenCV)
        text_outline_color = (0, 0, 0)  # Black
        font_scale = 1.0                # Increased font size
        font_thickness = 2

        # Draw outline first for better readability
        cv2.putText(frame, current_prediction_label, (10, 40),  # Slightly lower position
                    cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_outline_color, font_thickness + 2, cv2.LINE_AA)
        # Draw actual text
        cv2.putText(frame, current_prediction_label, (10, 40),
                    cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_color, font_thickness, cv2.LINE_AA)

        out.write(frame)

    cap.release()
    out.release()

    print(f"Video processing complete. Output saved to: {output_video_path}")

    return output_video_path


# --- 5. Gradio Interface Setup ---
iface = gr.Interface(
    fn=predict_video,
    inputs=gr.Video(label="Upload Video for Violence Detection (MP4 recommended)"),
    outputs=gr.Video(label="Processed Video with Predictions"),
    title="Real-time Violence Detection with SmallVideoClassifier",
    description="Upload a video, and the model will analyze it for violence, displaying the predicted class and confidence on each frame.",
    allow_flagging="never",
    examples=[
        # Add example videos here for easier testing and demonstration
        # E.g., a sample video that's publicly accessible:
        # "https://huggingface.co/datasets/gradio/test-files/resolve/main/video.mp4"
    ]
)

iface.launch()