owinymarvin committed
Commit 11e2014 · 1 Parent(s): 320fc26

latest changes

Files changed (2)
  1. app.py +124 -157
  2. requirements.txt +3 -1
app.py CHANGED
@@ -1,164 +1,131 @@
-import gradio as gr
 import torch
-import cv2
-import numpy as np
-import os
 import json
-from PIL import Image
-from torchvision import transforms
-from huggingface_hub import hf_hub_download
-import time
-
-# --- 1. Define Model Architecture ---
-class SmallVideoClassifier(torch.nn.Module):
-    def __init__(self, num_classes=2, num_frames=8):
-        super(SmallVideoClassifier, self).__init__()
-        from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
-        try:
-            weights = MobileNet_V3_Small_Weights.IMAGENET1K_V1
-        except Exception:
-            print("Warning: MobileNet_V3_Small_Weights.IMAGENET1K_V1 not found, initializing without pre-trained weights.")
-            weights = None
-
-        self.feature_extractor = mobilenet_v3_small(weights=weights)
-        self.feature_extractor.classifier = torch.nn.Identity()
-        self.num_spatial_features = 576
-        self.temporal_aggregator = torch.nn.AdaptiveAvgPool1d(1)
-        self.classifier = torch.nn.Sequential(
-            torch.nn.Linear(self.num_spatial_features, 512),
-            torch.nn.ReLU(),
-            torch.nn.Dropout(0.2),
-            torch.nn.Linear(512, num_classes)
-        )

-    def forward(self, pixel_values):
-        batch_size, num_frames, channels, height, width = pixel_values.shape
-        x = pixel_values.view(batch_size * num_frames, channels, height, width)
-        spatial_features = self.feature_extractor(x)
-        spatial_features = spatial_features.view(batch_size, num_frames, self.num_spatial_features)
-        temporal_features = self.temporal_aggregator(spatial_features.permute(0, 2, 1)).squeeze(-1)
-        logits = self.classifier(temporal_features)
-        return logits
-
-# --- 2. Configuration and Model Loading ---
-HF_USERNAME = "owinymarvin"
-NEW_MODEL_REPO_ID_SHORT = "timesformer-violence-detector"
-NEW_MODEL_REPO_ID = f"{HF_USERNAME}/{NEW_MODEL_REPO_ID_SHORT}"
-
-print(f"Downloading config.json from {NEW_MODEL_REPO_ID}...")
-config_path = hf_hub_download(repo_id=NEW_MODEL_REPO_ID, filename="config.json")
-with open(config_path, 'r') as f:
-    model_config = json.load(f)
-
-NUM_FRAMES = model_config.get('num_frames', 8)
-IMAGE_SIZE = tuple(model_config.get('image_size', [224, 224]))
-NUM_CLASSES = model_config.get('num_classes', 2)
-
-CLASS_LABELS = ["Non-violence", "Violence"]
-if NUM_CLASSES != len(CLASS_LABELS):
-    print(f"Warning: NUM_CLASSES in config ({NUM_CLASSES}) does not match hardcoded CLASS_LABELS length ({len(CLASS_LABELS)}). Adjust CLASS_LABELS if needed.")
-
-device = torch.device("cpu")
-print(f"Using device: {device}")
-
-model = SmallVideoClassifier(num_classes=NUM_CLASSES, num_frames=NUM_FRAMES)
-
-print(f"Downloading model weights from {NEW_MODEL_REPO_ID}...")
-model_weights_path = hf_hub_download(repo_id=NEW_MODEL_REPO_ID, filename="small_violence_classifier.pth")
-model.load_state_dict(torch.load(model_weights_path, map_location=device))
-model.to(device)
-model.eval()
-
-print(f"Model loaded successfully with {NUM_FRAMES} frames and image size {IMAGE_SIZE}.")
-
-# --- 3. Define Preprocessing Transform ---
-transform = transforms.Compose([
-    transforms.Resize(IMAGE_SIZE),
-    transforms.ToTensor(),
-    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
-])
-
-# --- Global state for the generator function ---
-frame_buffer = []
-current_prediction_label = "Initializing..."
-current_probabilities = {label: 0.0 for label in CLASS_LABELS}
-
-# --- 4. Gradio Live Inference Function (Generator) ---
-# This function will receive individual frames from the webcam as a NumPy array (H, W, C, RGB)
-def predict_live_frames(input_frame):
-    global frame_buffer, current_prediction_label, current_probabilities
-
-    if input_frame is None:
-        dummy_frame = np.zeros((200, 400, 3), dtype=np.uint8)
-        cv2.putText(dummy_frame, "Waiting for webcam input...", (10, 100), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
-        yield dummy_frame
-        return
-
-    pil_image = Image.fromarray(input_frame)
-    processed_frame_tensor = transform(pil_image)
-    frame_buffer.append(processed_frame_tensor)
-
-    slide_window_by = 1
-
-    if len(frame_buffer) >= NUM_FRAMES:
-        input_tensor = torch.stack(frame_buffer[-NUM_FRAMES:], dim=0).unsqueeze(0).to(device)
-
-        with torch.no_grad():
-            outputs = model(input_tensor)
-            probabilities = torch.softmax(outputs, dim=1)
-            predicted_class_idx = torch.argmax(probabilities, dim=1).item()
-
-        current_prediction_label = f"Class: {CLASS_LABELS[predicted_class_idx]}"
-        current_probabilities = {CLASS_LABELS[i]: prob.item() for i, prob in enumerate(probabilities[0])}

-        frame_buffer = frame_buffer[slide_window_by:]
-
-    display_frame = cv2.cvtColor(input_frame, cv2.COLOR_RGB2BGR)
-
-    # Draw the main prediction label
-    text_color = (0, 255, 0)  # Green (BGR)
-    text_outline_color = (0, 0, 0)  # Black
-    font_scale = 1.0
-    font_thickness = 2
-
-    cv2.putText(display_frame, current_prediction_label, (10, 40),
-                cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_outline_color, font_thickness + 2, cv2.LINE_AA)
-    cv2.putText(display_frame, current_prediction_label, (10, 40),
-                cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_color, font_thickness, cv2.LINE_AA)
-
-    # Draw probabilities for all classes
-    y_offset = 80
-    for label, prob in current_probabilities.items():
-        prob_text = f"{label}: {prob:.2f}"
-        cv2.putText(display_frame, prob_text, (10, y_offset),
-                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, text_outline_color, 2, cv2.LINE_AA)
-        cv2.putText(display_frame, prob_text, (10, y_offset),
-                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 0), 1, cv2.LINE_AA)
-        y_offset += 30
-
-    yield cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB)
-
-# --- 5. Gradio Interface Setup (with hidden buttons) ---
-iface = gr.Interface(
-    fn=predict_live_frames,
-    # Input: Live webcam feed, configure for streaming
-    inputs=gr.Video(sources=["webcam"], streaming=True, label="Live Webcam Feed"),
-    # Output: Image component to display processed frames
-    outputs=gr.Image(type="numpy", label="Processed Feed with Predictions"),
-
-    title="Real-time Violence Detection with SmallVideoClassifier (Webcam)",
-    description=(
-        "This model analyzes your live webcam feed for violence, displaying the predicted class and probabilities on the screen. "
-        "Please grant webcam access when prompted by your browser."
     ),
-
-    # --- IMPORTANT: Hide the default submit/clear buttons ---
-    submit_btn=None,
-    clear_btn=None,
-
-    allow_flagging="never",
-    # No examples needed for live webcam
-    examples=None
 )

-iface.launch()
 import torch
+# Choose the `slowfast_r50` model
+model = torch.hub.load('facebookresearch/pytorchvideo', 'slowfast_r50', pretrained=True)
+from typing import Dict
 import json
+import urllib.request
+from torchvision.transforms import Compose, Lambda
+from torchvision.transforms._transforms_video import (
+    CenterCropVideo,
+    NormalizeVideo,
+)
+from pytorchvideo.data.encoded_video import EncodedVideo
+from pytorchvideo.transforms import (
+    ApplyTransformToKey,
+    ShortSideScale,
+    UniformTemporalSubsample,
+    UniformCropVideo
+)

+import gradio as gr
+# Set to GPU or CPU
+device = "cpu"
+model = model.eval()
+model = model.to(device)
+
+# Download the Kinetics 400 class-name mapping
+json_url = "https://dl.fbaipublicfiles.com/pyslowfast/dataset/class_names/kinetics_classnames.json"
+json_filename = "kinetics_classnames.json"
+urllib.request.urlretrieve(json_url, json_filename)
+with open(json_filename, "r") as f:
+    kinetics_classnames = json.load(f)
+
+# Create an id to label name mapping
+kinetics_id_to_classname = {}
+for k, v in kinetics_classnames.items():
+    kinetics_id_to_classname[v] = str(k).replace('"', "")
+
+# Input transform parameters specific to slowfast_r50
+side_size = 256
+mean = [0.45, 0.45, 0.45]
+std = [0.225, 0.225, 0.225]
+crop_size = 256
+num_frames = 32
+sampling_rate = 2
+frames_per_second = 30
+slowfast_alpha = 4
+num_clips = 10
+num_crops = 3
+
+class PackPathway(torch.nn.Module):
+    """
+    Transform for packing a video clip into a list of [slow, fast] pathway tensors.
+    """
+    def __init__(self):
+        super().__init__()

+    def forward(self, frames: torch.Tensor):
+        fast_pathway = frames
+        # Perform temporal sampling from the fast pathway to build the slow pathway.
+        slow_pathway = torch.index_select(
+            frames,
+            1,
+            torch.linspace(
+                0, frames.shape[1] - 1, frames.shape[1] // slowfast_alpha
+            ).long(),
+        )
+        frame_list = [slow_pathway, fast_pathway]
+        return frame_list
+
+transform = ApplyTransformToKey(
+    key="video",
+    transform=Compose(
+        [
+            UniformTemporalSubsample(num_frames),
+            Lambda(lambda x: x / 255.0),
+            NormalizeVideo(mean, std),
+            ShortSideScale(
+                size=side_size
+            ),
+            CenterCropVideo(crop_size),
+            PackPathway()
+        ]
     ),
 )

+# The duration of the input clip is also specific to the model.
+clip_duration = (num_frames * sampling_rate) / frames_per_second
+url_link = "https://dl.fbaipublicfiles.com/pytorchvideo/projects/archery.mp4"
+video_path = 'archery.mp4'
+urllib.request.urlretrieve(url_link, video_path)
+# Select the duration of the clip to load by specifying the start and end duration.
+# The start_sec should correspond to where the action occurs in the video.
+
+def inference(in_vid):
+    start_sec = 0
+    end_sec = start_sec + clip_duration
+
+    # Initialize an EncodedVideo helper class and load the video
+    video = EncodedVideo.from_path(in_vid)
+
+    # Load the desired clip
+    video_data = video.get_clip(start_sec=start_sec, end_sec=end_sec)
+
+    # Apply a transform to normalize the video input
+    video_data = transform(video_data)
+
+    # Move the inputs to the desired device
+    inputs = video_data["video"]
+    inputs = [i.to(device)[None, ...] for i in inputs]
+    # Pass the input clip through the model
+    preds = model(inputs)
+
+    # Get the predicted classes
+    post_act = torch.nn.Softmax(dim=1)
+    preds = post_act(preds)
+    pred_classes = preds.topk(k=5).indices[0]
+
+    # Map the predicted classes to the label names
+    pred_class_names = [kinetics_id_to_classname[int(i)] for i in pred_classes]
+    return ", ".join(pred_class_names)
+
+inputs = gr.Video(label="Input Video")
+outputs = gr.Textbox(label="Top 5 predicted labels")
+
+title = "SlowFast"
+description = "Demo of SlowFast networks pretrained on the Kinetics 400 dataset. To use it, upload a video or click the example to load it. Read more at the links below."
+article = "<p style='text-align: center'><a href='https://arxiv.org/abs/1812.03982'>SlowFast Networks for Video Recognition</a> | <a href='https://github.com/facebookresearch/pytorchvideo'>Github Repo</a></p>"
+
+examples = [
+    ['archery.mp4']
+]
+
+gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=examples, analytics_enabled=False).launch(debug=True)
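A note on the quote-stripping in the class-name loop above: the keys in kinetics_classnames.json carry embedded double quotes, so the loop removes them while inverting the name-to-id JSON into an id-to-name table. A minimal sketch; the sample entry is illustrative, not copied from the file:

# Hypothetical entry shaped like those in kinetics_classnames.json
kinetics_classnames = {'"abseiling"': 0}
kinetics_id_to_classname = {v: str(k).replace('"', "") for k, v in kinetics_classnames.items()}
print(kinetics_id_to_classname[0])  # abseiling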
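The PackPathway transform is the SlowFast-specific step: the model takes a list of two clips, a temporally subsampled "slow" pathway and the full-rate "fast" pathway. A minimal sketch of the shapes involved, assuming the (C, T, H, W) tensor the Compose pipeline above emits:

import torch

slowfast_alpha = 4
clip = torch.randn(3, 32, 256, 256)  # (C, T, H, W) after UniformTemporalSubsample and CenterCropVideo

fast_pathway = clip                  # all 32 frames
slow_pathway = torch.index_select(   # every slowfast_alpha-th frame
    clip,
    1,
    torch.linspace(0, clip.shape[1] - 1, clip.shape[1] // slowfast_alpha).long(),
)
print(fast_pathway.shape)  # torch.Size([3, 32, 256, 256])
print(slow_pathway.shape)  # torch.Size([3, 8, 256, 256])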
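The clip_duration arithmetic is also worth a quick check: with num_frames = 32, sampling_rate = 2 and frames_per_second = 30,

clip_duration = (32 * 2) / 30  # ~2.13 seconds

so inference always scores roughly the first 2.13 seconds of the uploaded video, since start_sec is fixed at 0.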
requirements.txt CHANGED
@@ -4,4 +4,6 @@ opencv-python-headless # Use headless for server environments to avoid GUI depe
 gradio
 huggingface_hub
 Pillow
-numpy
+numpy
+av
+fvcore
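On the two new dependencies: av (PyAV) is the decoder that pytorchvideo's EncodedVideo uses by default, and fvcore is a library the hub-loaded pytorchvideo code depends on at runtime; both claims follow upstream pytorchvideo packaging and are worth re-checking against pinned versions. A quick local check of the decode path, reusing the archery.mp4 file the app downloads:

from pytorchvideo.data.encoded_video import EncodedVideo  # needs `av` (PyAV) installed

video = EncodedVideo.from_path("archery.mp4")  # decoded via PyAV by default
print(video.duration)  # clip length in seconds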