Commit 11e2014 ("latest changes")
Parent(s): 320fc26

Files changed:
- app.py: +124 -157
- requirements.txt: +3 -1
app.py
CHANGED
@@ -1,164 +1,131 @@
Removed (the previous app.py, a live-webcam violence-detection demo built around a MobileNetV3-Small classifier). Lines without a leading "-" are context kept in both versions; "…" marks removed lines whose content was not preserved in this view:

-import gradio as gr
 import torch
-
-
-import …                                   # import target truncated in this view
 import json
-
-from torchvision import transforms          # line truncated in this view; "transforms" inferred from its use below
-from …                                      # import line truncated in this view
-… (removed lines not preserved in this view) …
-            print("Warning: MobileNet_V3_Small_Weights.IMAGENET1K_V1 not found, initializing without pre-trained weights.")
-            weights = None
-
-        self.feature_extractor = mobilenet_v3_small(weights=weights)
-        self.feature_extractor.classifier = torch.nn.Identity()
-        self.num_spatial_features = 576
-        self.temporal_aggregator = torch.nn.AdaptiveAvgPool1d(1)
-        self.classifier = torch.nn.Sequential(
-            torch.nn.Linear(self.num_spatial_features, 512),
-            torch.nn.ReLU(),
-            torch.nn.Dropout(0.2),
-            torch.nn.Linear(512, num_classes)
-        )
 
-… (removed lines not preserved in this view) …
-model_weights_path = hf_hub_download(repo_id=NEW_MODEL_REPO_ID, filename="small_violence_classifier.pth")
-model.load_state_dict(torch.load(model_weights_path, map_location=device))
-model.to(device)
-model.eval()
-
-print(f"Model loaded successfully with {NUM_FRAMES} frames and image size {IMAGE_SIZE}.")
-
-# --- 3. Define Preprocessing Transform ---
-transform = transforms.Compose([
-    transforms.Resize(IMAGE_SIZE),
-    transforms.ToTensor(),
-    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
-])
-
-# --- Global state for the generator function ---
-frame_buffer = []
-current_prediction_label = "Initializing..."
-current_probabilities = {label: 0.0 for label in CLASS_LABELS}
-
-# --- 4. Gradio Live Inference Function (Generator) ---
-# This function will receive individual frames from the webcam as a NumPy array (H, W, C, RGB)
-def predict_live_frames(input_frame):
-    global frame_buffer, current_prediction_label, current_probabilities
-
-    if input_frame is None:
-        dummy_frame = np.zeros((200, 400, 3), dtype=np.uint8)
-        cv2.putText(dummy_frame, "Waiting for webcam input...", (10, 100), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
-        yield dummy_frame
-        return
-
-    pil_image = Image.fromarray(input_frame)
-    processed_frame_tensor = transform(pil_image)
-    frame_buffer.append(processed_frame_tensor)
-
-    slide_window_by = 1
-
-    if len(frame_buffer) >= NUM_FRAMES:
-        input_tensor = torch.stack(frame_buffer[-NUM_FRAMES:], dim=0).unsqueeze(0).to(device)
-
-        with torch.no_grad():
-            outputs = model(input_tensor)
-            probabilities = torch.softmax(outputs, dim=1)
-            predicted_class_idx = torch.argmax(probabilities, dim=1).item()
-
-        current_prediction_label = f"Class: {CLASS_LABELS[predicted_class_idx]}"
-        current_probabilities = {CLASS_LABELS[i]: prob.item() for i, prob in enumerate(probabilities[0])}
 
-… (removed lines not preserved in this view) …
-# --- 5. Gradio Interface Setup (with hidden buttons) ---
-iface = gr.Interface(
-    fn=predict_live_frames,
-    # Input: Live webcam feed, configure for streaming
-    inputs=gr.Video(sources=["webcam"], streaming=True, label="Live Webcam Feed"),
-    # Output: Image component to display processed frames
-    outputs=gr.Image(type="numpy", label="Processed Feed with Predictions"),
-
-    title="Real-time Violence Detection with SmallVideoClassifier (Webcam)",
-    description=(
-        "This model analyzes your live webcam feed for violence, displaying the predicted class and probabilities on the screen. "
-        "Please grant webcam access when prompted by your browser."
     ),
-
-    # --- IMPORTANT: Hide the default submit/clear buttons ---
-    submit_btn=None,
-    clear_btn=None,
-
-    allow_flagging="never",
-    # No examples needed for live webcam
-    examples=None
 )
 
-
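The removed file's model definition is only partially preserved above; its forward pass and the surrounding class and constant definitions (NUM_FRAMES, IMAGE_SIZE, CLASS_LABELS, NEW_MODEL_REPO_ID) are missing from this view. A minimal sketch of how the visible pieces typically fit together, assuming a per-frame MobileNetV3-Small backbone followed by temporal average pooling; the class name is taken from the removed title string, and the forward logic and num_classes default are illustrative assumptions, not the file's actual code:

import torch
from torchvision.models import mobilenet_v3_small

class SmallVideoClassifier(torch.nn.Module):              # name assumed from the removed app's title
    def __init__(self, num_classes=2, weights=None):      # num_classes default is an assumption
        super().__init__()
        self.feature_extractor = mobilenet_v3_small(weights=weights)
        self.feature_extractor.classifier = torch.nn.Identity()   # expose the 576-d pooled features
        self.num_spatial_features = 576
        self.temporal_aggregator = torch.nn.AdaptiveAvgPool1d(1)
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(self.num_spatial_features, 512),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(512, num_classes),
        )

    def forward(self, x):                                    # x: (batch, frames, C, H, W)
        b, t = x.shape[:2]
        feats = self.feature_extractor(x.flatten(0, 1))      # (b*t, 576) per-frame features
        feats = feats.view(b, t, -1).permute(0, 2, 1)        # (b, 576, t)
        pooled = self.temporal_aggregator(feats).squeeze(-1) # (b, 576) averaged over time
        return self.classifier(pooled)                       # (b, num_classes) logits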
Added (the new app.py, a SlowFast action-recognition demo on Kinetics-400). Lines without a leading "+" are context kept from the previous version:

 import torch
+# Choose the `slowfast_r50` model
+model = torch.hub.load('facebookresearch/pytorchvideo', 'slowfast_r50', pretrained=True)
+from typing import Dict
 import json
+import urllib
+from torchvision.transforms import Compose, Lambda
+from torchvision.transforms._transforms_video import (
+    CenterCropVideo,
+    NormalizeVideo,
+)
+from pytorchvideo.data.encoded_video import EncodedVideo
+from pytorchvideo.transforms import (
+    ApplyTransformToKey,
+    ShortSideScale,
+    UniformTemporalSubsample,
+    UniformCropVideo
+)
 
+import gradio as gr
+# Set to GPU or CPU
+device = "cpu"
+model = model.eval()
+model = model.to(device)
+json_url = "https://dl.fbaipublicfiles.com/pyslowfast/dataset/class_names/kinetics_classnames.json"
+json_filename = "kinetics_classnames.json"
+try: urllib.URLopener().retrieve(json_url, json_filename)
+except: urllib.request.urlretrieve(json_url, json_filename)
+with open(json_filename, "r") as f:
+    kinetics_classnames = json.load(f)
+
+# Create an id to label name mapping
+kinetics_id_to_classname = {}
+for k, v in kinetics_classnames.items():
+    kinetics_id_to_classname[v] = str(k).replace('"', "")
+side_size = 256
+mean = [0.45, 0.45, 0.45]
+std = [0.225, 0.225, 0.225]
+crop_size = 256
+num_frames = 32
+sampling_rate = 2
+frames_per_second = 30
+slowfast_alpha = 4
+num_clips = 10
+num_crops = 3
+
+class PackPathway(torch.nn.Module):
+    """
+    Transform for converting video frames as a list of tensors.
+    """
+    def __init__(self):
+        super().__init__()
 
+    def forward(self, frames: torch.Tensor):
+        fast_pathway = frames
+        # Perform temporal sampling from the fast pathway.
+        slow_pathway = torch.index_select(
+            frames,
+            1,
+            torch.linspace(
+                0, frames.shape[1] - 1, frames.shape[1] // slowfast_alpha
+            ).long(),
+        )
+        frame_list = [slow_pathway, fast_pathway]
+        return frame_list
+
+transform = ApplyTransformToKey(
+    key="video",
+    transform=Compose(
+        [
+            UniformTemporalSubsample(num_frames),
+            Lambda(lambda x: x/255.0),
+            NormalizeVideo(mean, std),
+            ShortSideScale(
+                size=side_size
+            ),
+            CenterCropVideo(crop_size),
+            PackPathway()
+        ]
     ),
 )
 
+# The duration of the input clip is also specific to the model.
+clip_duration = (num_frames * sampling_rate)/frames_per_second
+url_link = "https://dl.fbaipublicfiles.com/pytorchvideo/projects/archery.mp4"
+video_path = 'archery.mp4'
+try: urllib.URLopener().retrieve(url_link, video_path)
+except: urllib.request.urlretrieve(url_link, video_path)
+# Select the duration of the clip to load by specifying the start and end duration
+# The start_sec should correspond to where the action occurs in the video
+
+def inference(in_vid):
+    start_sec = 0
+    end_sec = start_sec + clip_duration
+
+    # Initialize an EncodedVideo helper class and load the video
+    video = EncodedVideo.from_path(in_vid)
+
+    # Load the desired clip
+    video_data = video.get_clip(start_sec=start_sec, end_sec=end_sec)
+
+    # Apply a transform to normalize the video input
+    video_data = transform(video_data)
+
+    # Move the inputs to the desired device
+    inputs = video_data["video"]
+    inputs = [i.to(device)[None, ...] for i in inputs]
+    # Pass the input clip through the model
+    preds = model(inputs)
+
+    # Get the predicted classes
+    post_act = torch.nn.Softmax(dim=1)
+    preds = post_act(preds)
+    pred_classes = preds.topk(k=5).indices[0]
+
+    # Map the predicted classes to the label names
+    pred_class_names = [kinetics_id_to_classname[int(i)] for i in pred_classes]
+    return "%s" % ", ".join(pred_class_names)
+
+inputs = gr.inputs.Video(label="Input Video")
+outputs = gr.outputs.Textbox(label="Top 5 predicted labels")
+
+title = "SLOWFAST"
+description = "demo for SLOWFAST, SlowFast networks pretrained on the Kinetics 400 dataset. To use it, simply upload your video, or click one of the examples to load them. Read more at the links below."
+article = "<p style='text-align: center'><a href='https://arxiv.org/abs/1812.03982'>SlowFast Networks for Video Recognition</a> | <a href='https://github.com/facebookresearch/pytorchvideo'>Github Repo</a></p>"
+
+examples = [
+    ['archery.mp4']
+]
+
+gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=examples, analytics_enabled=False).launch(debug=True)
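Two notes on the added code, for context rather than as part of the commit. First, with the constants above, the transform turns one decoded clip into the two SlowFast pathways (the slow pathway keeps num_frames // slowfast_alpha = 8 frames, the fast pathway all 32), and clip_duration works out to (32 * 2) / 30, roughly 2.13 seconds. A small sanity-check sketch, assuming the definitions above are in scope and using a random tensor in place of a real decoded clip:

import torch

# The transform expects a dict with a "video" tensor shaped (C, T, H, W).
dummy = {"video": torch.rand(3, 64, 256, 320) * 255.0}   # 64 raw frames, arbitrary spatial size
slow, fast = transform(dummy)["video"]                    # PackPathway returns [slow, fast]
print(slow.shape)     # torch.Size([3, 8, 256, 256])   -> num_frames // slowfast_alpha frames
print(fast.shape)     # torch.Size([3, 32, 256, 256])  -> num_frames frames
print(clip_duration)  # 2.1333...                      -> (num_frames * sampling_rate) / frames_per_second

Second, gr.inputs.Video and gr.outputs.Textbox come from Gradio's older input/output namespaces, which newer Gradio releases have removed; there the equivalents are gr.Video(label="Input Video") and gr.Textbox(label="Top 5 predicted labels") passed directly to gr.Interface.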
requirements.txt
CHANGED
@@ -4,4 +4,6 @@ opencv-python-headless # Use headless for server environments to avoid GUI depe
 gradio
 huggingface_hub
 Pillow
-numpy
+numpy
+av
+fvcore
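The new requirements support the PyTorchVideo-based pipeline in the rewritten app.py: av (PyAV) provides the decoder that EncodedVideo.from_path uses to read uploaded clips, and fvcore is required by PyTorchVideo and the slowfast_r50 model loaded through torch.hub. The numpy entry appears on both sides of the hunk (removed and re-added), so the effective additions are av and fvcore, matching the +3/-1 stats.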