import gradio as gr
import torch
import cv2
import numpy as np
import os
import json
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download
import tempfile # For temporary file handling

# --- 1. Define Model Architecture (Copy from small_video_classifier.py) ---
# This is crucial because we need the model class definition to load weights.
class SmallVideoClassifier(torch.nn.Module):
    def __init__(self, num_classes=2, num_frames=8):
        super(SmallVideoClassifier, self).__init__()
        from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
        # Fall back to a randomly initialized backbone if the ImageNet weights
        # cannot be obtained (e.g. offline environments or older torchvision builds).
        try:
            weights = MobileNet_V3_Small_Weights.IMAGENET1K_V1
            self.feature_extractor = mobilenet_v3_small(weights=weights)
        except Exception:
            print("Warning: could not load MobileNet_V3_Small IMAGENET1K_V1 weights; initializing the backbone without pre-trained weights.")
            self.feature_extractor = mobilenet_v3_small(weights=None)
        self.feature_extractor.classifier = torch.nn.Identity()
        self.num_spatial_features = 576  # Dimension of MobileNetV3-Small's pooled feature vector
        self.temporal_aggregator = torch.nn.AdaptiveAvgPool1d(1)
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(self.num_spatial_features, 512),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(512, num_classes)
        )

    def forward(self, pixel_values):
        batch_size, num_frames, channels, height, width = pixel_values.shape
        x = pixel_values.view(batch_size * num_frames, channels, height, width)
        spatial_features = self.feature_extractor(x)
        spatial_features = spatial_features.view(batch_size, num_frames, self.num_spatial_features)
        temporal_features = self.temporal_aggregator(spatial_features.permute(0, 2, 1)).squeeze(-1)
        logits = self.classifier(temporal_features)
        return logits
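
# A quick shape sanity check for the model above (illustrative only; the sizes
# assume the default num_frames=8 and 224x224 inputs):
#   clip = torch.randn(1, 8, 3, 224, 224)     # (batch, num_frames, C, H, W)
#   logits = SmallVideoClassifier()(clip)     # -> torch.Size([1, 2])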

# --- 2. Configuration and Model Loading ---
HF_USERNAME = "owinymarvin"
NEW_MODEL_REPO_ID_SHORT = "timesformer-violence-detector"
NEW_MODEL_REPO_ID = f"{HF_USERNAME}/{NEW_MODEL_REPO_ID_SHORT}"

# Download config.json to get model parameters
print(f"Downloading config.json from {NEW_MODEL_REPO_ID}...")
config_path = hf_hub_download(repo_id=NEW_MODEL_REPO_ID, filename="config.json")
with open(config_path, 'r') as f:
    model_config = json.load(f)
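
# The config is expected to expose at least the keys read below; the values shown
# here are only the .get() fallbacks, not necessarily the repo's actual settings:
#   {"num_frames": 8, "image_size": [224, 224], "num_classes": 2}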

NUM_FRAMES = model_config.get('num_frames', 8)
IMAGE_SIZE = tuple(model_config.get('image_size', [224, 224]))
NUM_CLASSES = model_config.get('num_classes', 2)

# Define class labels (adjust if your dataset had different labels/order)
CLASS_LABELS = ["Non-violence", "Violence"]
if NUM_CLASSES != len(CLASS_LABELS):
    print(f"Warning: NUM_CLASSES in config ({NUM_CLASSES}) does not match hardcoded CLASS_LABELS length ({len(CLASS_LABELS)}). Adjust CLASS_LABELS if needed.")


# Initialize the model
device = torch.device("cpu")  # Run on CPU (typical for free Hugging Face Spaces hardware)
print(f"Using device: {device}")

model = SmallVideoClassifier(num_classes=NUM_CLASSES, num_frames=NUM_FRAMES)

# Download model weights
print(f"Downloading model weights from {NEW_MODEL_REPO_ID}...")
model_weights_path = hf_hub_download(repo_id=NEW_MODEL_REPO_ID, filename="small_violence_classifier.pth")
model.load_state_dict(torch.load(model_weights_path, map_location=device))
model.to(device)
model.eval() # Set model to evaluation mode

print(f"Model loaded successfully with {NUM_FRAMES} frames and image size {IMAGE_SIZE}.")

# --- 3. Define Preprocessing Transform ---
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
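
# The transform maps each PIL RGB frame to a float tensor normalized with the ImageNet
# statistics the MobileNetV3 backbone was pre-trained on. Purely illustrative example
# (hypothetical blank frame, assuming the default 224x224 IMAGE_SIZE):
#   t = transform(Image.new("RGB", (640, 480)))  # -> torch.Size([3, 224, 224])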

# --- 4. Gradio Inference Function ---
def predict_video(video_path):
    if video_path is None:
        return None # Or raise an error, or return a placeholder video

    cap = cv2.VideoCapture(video_path)

    if not cap.isOpened():
        print(f"Error: Could not open video file {video_path}.")
        raise ValueError(f"Could not open video file {video_path}. Please ensure it's a valid video format.")

    # Get video properties
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps <= 0:
        fps = 30.0  # Some containers report 0 FPS; fall back to a sane default for the writer
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Create a temporary output video file
    # Use tempfile so the output clip gets a unique, writable path on Hugging Face Spaces
    temp_output_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    output_video_path = temp_output_file.name
    temp_output_file.close() # Close the file handle as cv2.VideoWriter needs to open it

    # Define the codec and create VideoWriter object
    # 'mp4v' encodes reliably with stock OpenCV builds; note that some browsers need
    # H.264 ('avc1') for inline playback, which OpenCV may not provide by default.
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))

    print(f"Processing video: {video_path}")
    print(f"Total frames: {total_frames}, FPS: {fps}")
    print(f"Output video will be saved to: {output_video_path}")

    frame_buffer = [] # To store NUM_FRAMES for each prediction batch
    current_prediction_label = "Processing..." # Initial label

    frame_idx = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break # End of video

        frame_idx += 1
        
        # Convert frame from BGR (OpenCV) to RGB (PIL/PyTorch)
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(frame_rgb)
        
        # Apply transformations and add to buffer
        processed_frame = transform(pil_image) # shape: (C, H, W)
        frame_buffer.append(processed_frame)

        # Perform prediction when the buffer is full
        if len(frame_buffer) == NUM_FRAMES:
            # Stack the buffered frames and add a batch dimension
            # Resulting shape: (1, NUM_FRAMES, C, H, W)
            input_tensor = torch.stack(frame_buffer, dim=0).unsqueeze(0).to(device)

            with torch.no_grad():
                outputs = model(input_tensor)
                probabilities = torch.softmax(outputs, dim=1)
                predicted_class_idx = torch.argmax(probabilities, dim=1).item()
                current_prediction_label = f"Prediction: {CLASS_LABELS[predicted_class_idx]} (Prob: {probabilities[0, predicted_class_idx]:.2f})"
            
            # Reset buffer for the next non-overlapping window
            frame_buffer = [] 
            # Or, if you want sliding window (more continuous output but higher compute):
            # frame_buffer = frame_buffer[1:] # e.g., slide by 1 frame

        # Draw prediction text on the current frame
        # The prediction will lag by NUM_FRAMES, as it's based on the previous batch.
        # We display the last known prediction.
        cv2.putText(frame, current_prediction_label, (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2, cv2.LINE_AA)

        # Write the processed frame to the output video
        out.write(frame)

    # Release resources
    cap.release()
    out.release()
    print(f"Video processing complete. Output saved to: {output_video_path}")
    
    return output_video_path # Gradio expects the path to the output video
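
# Local sanity check outside the Gradio UI (hypothetical file name, shown only to
# illustrate the function's contract):
#   annotated_path = predict_video("some_clip.mp4")
#   print(annotated_path)  # temporary .mp4 with the per-window prediction overlay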

# --- 5. Gradio Interface Setup ---
iface = gr.Interface(
    fn=predict_video,
    # gr.Video passes the uploaded file's path (a str) to the prediction function.
    inputs=gr.Video(label="Upload Video for Violence Detection (MP4 recommended)"),
    outputs=gr.Video(label="Processed Video with Predictions"),
    title="Real-time Violence Detection with SmallVideoClassifier",
    description="Upload a video, and the model will analyze it for violence, displaying the predicted class and confidence on each frame.",
    allow_flagging="never", # Disable flagging on Hugging Face Spaces
    examples=[
        # You can provide example video URLs or paths if you have them publicly available
        # e.g., "https://huggingface.co/datasets/gradio/test-files/resolve/main/video.mp4"
    ]
)

iface.launch()