File size: 6,851 Bytes
8f7ac8f
 
 
 
6c2ffca
 
 
 
 
429fbf1
6c2ffca
429fbf1
6c2ffca
 
 
 
 
 
 
 
4ad907b
6c2ffca
 
 
4ad907b
6c2ffca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429fbf1
6c2ffca
 
 
 
 
 
 
8f7ac8f
4ad907b
6c2ffca
 
 
 
 
 
 
 
 
 
429fbf1
 
998f789
429fbf1
998f789
429fbf1
89b0c64
429fbf1
89b0c64
 
 
 
 
429fbf1
89b0c64
 
 
 
 
429fbf1
 
 
 
89b0c64
 
 
 
 
6c2ffca
89b0c64
 
4ad907b
429fbf1
 
89b0c64
 
 
 
 
 
 
028d725
89b0c64
 
 
 
 
429fbf1
 
89b0c64
 
 
 
 
429fbf1
 
89b0c64
 
 
429fbf1
 
 
d276007
429fbf1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d276007
 
429fbf1
 
 
 
 
d276007
 
429fbf1
 
 
d276007
 
 
429fbf1
89b0c64
429fbf1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import gradio as gr
import torch
import cv2
import numpy as np
import os
import json
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download
import time

# --- 1. Define Model Architecture ---
class SmallVideoClassifier(torch.nn.Module):
    """Per-frame MobileNetV3-Small encoder followed by temporal average pooling.

    Every frame of a clip is encoded independently by a MobileNetV3-Small
    backbone whose classification head is replaced by an identity, giving a
    576-dimensional feature per frame. Frame features are averaged over time
    and passed to a small MLP head that produces class logits.

    Args:
        num_classes: Number of output logits.
        num_frames: Clip length the model was configured for. It is stored
            for reference only; the temporal average pool accepts any number
            of frames.
    """

    def __init__(self, num_classes=2, num_frames=8):
        super().__init__()
        # torchvision is only imported when a model is actually constructed.
        from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
        try:
            weights = MobileNet_V3_Small_Weights.IMAGENET1K_V1
        except AttributeError:
            # Only a missing enum member can fail here; don't mask other errors.
            print("Warning: MobileNet_V3_Small_Weights.IMAGENET1K_V1 not found, initializing without pre-trained weights.")
            weights = None

        self.num_frames = num_frames
        self.feature_extractor = mobilenet_v3_small(weights=weights)
        # Drop the ImageNet head so the backbone returns pooled 576-d features.
        self.feature_extractor.classifier = torch.nn.Identity()
        self.num_spatial_features = 576
        self.temporal_aggregator = torch.nn.AdaptiveAvgPool1d(1)
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(self.num_spatial_features, 512),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(512, num_classes),
        )

    def forward(self, pixel_values):
        """Return class logits for a batch of clips.

        Args:
            pixel_values: Tensor of shape (batch, frames, channels, height, width).
                It does not need to be contiguous.

        Returns:
            Tensor of shape (batch, num_classes) holding unnormalized logits.
        """
        batch_size, num_frames, channels, height, width = pixel_values.shape
        # reshape (not view) so non-contiguous inputs, e.g. permuted tensors, work.
        frames = pixel_values.reshape(batch_size * num_frames, channels, height, width)
        spatial_features = self.feature_extractor(frames)
        spatial_features = spatial_features.reshape(batch_size, num_frames, self.num_spatial_features)
        # AdaptiveAvgPool1d pools the last axis, so move time there: (B, F, T) -> (B, F).
        temporal_features = self.temporal_aggregator(spatial_features.permute(0, 2, 1)).squeeze(-1)
        return self.classifier(temporal_features)

# --- 2. Configuration and Model Loading ---
HF_USERNAME = "owinymarvin"
NEW_MODEL_REPO_ID_SHORT = "timesformer-violence-detector"
NEW_MODEL_REPO_ID = f"{HF_USERNAME}/{NEW_MODEL_REPO_ID_SHORT}"

# Hyper-parameters come from the repo's config.json; the defaults below apply
# to any key the file does not define.
print(f"Downloading config.json from {NEW_MODEL_REPO_ID}...")
config_path = hf_hub_download(repo_id=NEW_MODEL_REPO_ID, filename="config.json")
with open(config_path, 'r', encoding="utf-8") as f:
    model_config = json.load(f)

NUM_FRAMES = model_config.get('num_frames', 8)
IMAGE_SIZE = tuple(model_config.get('image_size', [224, 224]))
NUM_CLASSES = model_config.get('num_classes', 2)

CLASS_LABELS = ["Non-violence", "Violence"]
if NUM_CLASSES != len(CLASS_LABELS):
    print(f"Warning: NUM_CLASSES in config ({NUM_CLASSES}) does not match hardcoded CLASS_LABELS length ({len(CLASS_LABELS)}). Adjust CLASS_LABELS if needed.")
    # Keep exactly one label per model output so indexing CLASS_LABELS by a
    # predicted class index can never go out of range. Extra outputs get
    # generic names.
    CLASS_LABELS = (
        CLASS_LABELS + [f"Class {i}" for i in range(len(CLASS_LABELS), NUM_CLASSES)]
    )[:NUM_CLASSES]

device = torch.device("cpu")
print(f"Using device: {device}")

model = SmallVideoClassifier(num_classes=NUM_CLASSES, num_frames=NUM_FRAMES)

print(f"Downloading model weights from {NEW_MODEL_REPO_ID}...")
model_weights_path = hf_hub_download(repo_id=NEW_MODEL_REPO_ID, filename="small_violence_classifier.pth")
# The checkpoint is fetched over the network. weights_only=True restricts
# unpickling to tensors and plain containers, so a tampered file cannot run
# arbitrary code. A state_dict loads fine under this restriction.
state_dict = torch.load(model_weights_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

print(f"Model loaded successfully with {NUM_FRAMES} frames and image size {IMAGE_SIZE}.")

# --- 3. Define Preprocessing Transform ---
# PIL image -> resized to IMAGE_SIZE -> float tensor (C, H, W) in [0, 1]
# -> per-channel normalization with the ImageNet mean/std values.
_preprocess_steps = [
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocess_steps)

# --- Global state for the generator function ---
# Preprocessed frames collected until a full inference window is available.
frame_buffer = []
# Latest prediction text and per-class probabilities drawn on each frame.
current_prediction_label = "Initializing..."
current_probabilities = dict.fromkeys(CLASS_LABELS, 0.0)

# --- 4. Gradio Live Inference Function (Generator) ---
def _label_for(class_idx):
    """Return the display name for model output ``class_idx``.

    Falls back to ``"Class <idx>"`` when the model has more outputs than
    CLASS_LABELS has entries, instead of raising IndexError.
    """
    if class_idx < len(CLASS_LABELS):
        return CLASS_LABELS[class_idx]
    return f"Class {class_idx}"


def _waiting_frame():
    """Return a black placeholder image asking the user to start the webcam."""
    dummy_frame = np.zeros((200, 400, 3), dtype=np.uint8)
    cv2.putText(dummy_frame, "Waiting for webcam input...", (10, 100), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
    return dummy_frame


def _draw_outlined_text(image, text, origin, font_scale, color, thickness, outline_thickness):
    """Draw ``text`` in ``color`` on top of a thicker black outline, in place."""
    cv2.putText(image, text, origin, cv2.FONT_HERSHEY_SIMPLEX, font_scale,
                (0, 0, 0), outline_thickness, cv2.LINE_AA)
    cv2.putText(image, text, origin, cv2.FONT_HERSHEY_SIMPLEX, font_scale,
                color, thickness, cv2.LINE_AA)


def predict_live_frames(input_frame):
    """Classify the latest webcam frame and yield it annotated with the prediction.

    Each call preprocesses one frame and appends it to the global
    ``frame_buffer``. Once NUM_FRAMES frames are buffered, the model runs on
    the most recent NUM_FRAMES of them and the global prediction state is
    updated. The oldest frame is then dropped, so the next call runs on a
    window shifted by one frame. Until the first window is full, the initial
    label and probabilities are shown.

    Args:
        input_frame: Webcam frame as a numpy image array, or None when no frame
            is available yet. Assumed to be RGB; grayscale and RGBA arrays are
            converted to RGB first.

    Yields:
        Exactly one RGB numpy image: the frame with the prediction overlay, or
        a placeholder image when ``input_frame`` is None.
    """
    global frame_buffer, current_prediction_label, current_probabilities

    if input_frame is None:
        yield _waiting_frame()
        return

    # Force 3-channel RGB: Normalize uses 3-channel statistics and the
    # RGB->BGR conversion below cannot take 2-D (grayscale) input.
    pil_image = Image.fromarray(input_frame).convert("RGB")
    rgb_frame = np.asarray(pil_image)
    frame_buffer.append(transform(pil_image))

    slide_window_by = 1

    if len(frame_buffer) >= NUM_FRAMES:
        # (1, NUM_FRAMES, C, H, W) batch built from the newest frames.
        input_tensor = torch.stack(frame_buffer[-NUM_FRAMES:], dim=0).unsqueeze(0).to(device)
        with torch.no_grad():
            probabilities = torch.softmax(model(input_tensor), dim=1)[0]
        predicted_class_idx = int(torch.argmax(probabilities).item())

        current_prediction_label = f"Class: {_label_for(predicted_class_idx)}"
        current_probabilities = {_label_for(i): prob.item() for i, prob in enumerate(probabilities)}

        frame_buffer = frame_buffer[slide_window_by:]

    # OpenCV colour tuples below are BGR, so draw on a BGR copy and convert back.
    display_frame = cv2.cvtColor(rgb_frame, cv2.COLOR_RGB2BGR)

    # Main prediction label: green text over a black outline.
    _draw_outlined_text(display_frame, current_prediction_label, (10, 40),
                        font_scale=1.0, color=(0, 255, 0), thickness=2, outline_thickness=4)

    # One line per class probability, stacked 30 px apart.
    y_offset = 80
    for label, prob in current_probabilities.items():
        _draw_outlined_text(display_frame, f"{label}: {prob:.2f}", (10, y_offset),
                            font_scale=0.7, color=(255, 255, 0), thickness=1, outline_thickness=2)
        y_offset += 30

    yield cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB)

# --- 5. Gradio Blocks Interface Setup ---
# ``title`` sets the browser tab title.
with gr.Blocks(title="Real-time Violence Detection") as demo:
    gr.Markdown(
        """
        # 🎬 Real-time Violence Detection
        **Live Feed with Constant Predictions**

        This model analyzes your live webcam feed for violence, displaying the predicted class and probabilities on the screen.
        Please grant webcam access when prompted by your browser.
        """
    )

    with gr.Row():
        # Per-frame webcam streaming is provided by gr.Image (it exposes the
        # .stream event); gr.Video does not stream frames to a handler.
        # type="numpy" delivers each frame to predict_live_frames as an array.
        video_input = gr.Image(
            sources=["webcam"],
            streaming=True,
            type="numpy",
            label="Live Webcam Feed",
            height=480,
            width=640,
        )

        video_output = gr.Image(
            type="numpy",
            label="Processed Feed with Predictions",
            height=480,
            width=640,
        )

    video_input.stream(
        predict_live_frames,
        inputs=video_input,
        outputs=video_output,
    )

# Launch only when run as a script, so importing the module (e.g. by Gradio's
# reload mode) does not start a second server.
if __name__ == "__main__":
    demo.launch()