# Hugging Face Space (status when captured: Sleeping)
import json
import urllib
import urllib.request

import gradio as gr
import numpy as np  # Explicitly add numpy import
import torch
from pytorchvideo.data.encoded_video import EncodedVideo
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    ShortSideScale,
    UniformCropVideo,
    UniformTemporalSubsample,
)
from torchvision.transforms import Compose, Lambda
from torchvision.transforms._transforms_video import (
    CenterCropVideo,
    NormalizeVideo,
)
# Load the pretrained `slowfast_r50` model from the PyTorchVideo hub and
# put it in evaluation mode.
# Inference runs on the CPU since no GPU is available.
device = "cpu"
model = torch.hub.load(
    "facebookresearch/pytorchvideo", "slowfast_r50", pretrained=True
)
model = model.eval().to(device)
# --- Class Name Loading (from notebook) ---
# Fetch the Kinetics label file, which maps each class name to its integer id.
json_url = "https://dl.fbaipublicfiles.com/pyslowfast/dataset/class_names/kinetics_classnames.json"
json_filename = "kinetics_classnames.json"
# `urllib.URLopener` only existed in Python 2; on Python 3 that call always
# raised AttributeError, which a bare `except` silently swallowed. Download
# directly through `urllib.request` so real failures are not masked.
urllib.request.urlretrieve(json_url, json_filename)

with open(json_filename, "r", encoding="utf-8") as f:
    kinetics_classnames = json.load(f)

# Invert the mapping (id -> class name) and strip the double quotes that are
# embedded in the label strings.
kinetics_id_to_classname = {
    class_id: str(class_name).replace('"', "")
    for class_name, class_id in kinetics_classnames.items()
}
# --- Define Input Transform (from notebook) ---
# Spatial settings: short-side resize target and center-crop size.
side_size = 256
crop_size = 256

# Normalization statistics, one entry per channel (identical for all three).
mean = [0.45] * 3
std = [0.225] * 3

# Temporal settings: frames sampled per clip, plus the two values that are
# combined with it to derive the clip duration in seconds.
num_frames = 32
sampling_rate = 2
frames_per_second = 30

# Divisor PackPathway uses to pick the number of slow-pathway frames.
slowfast_alpha = 4

# num_clips = 10 # Not used in inference function
# num_crops = 3 # Not used in inference function
class PackPathway(torch.nn.Module):
    """
    Transform for converting video frames as a list of tensors.

    The output list is ``[slow_pathway, fast_pathway]``: the fast pathway is
    the input unchanged, and the slow pathway keeps evenly spaced frames
    selected along dimension 1 of the input.

    Args:
        alpha: Divisor applied to the number of frames to get the slow
            pathway's frame count. When ``None`` (the default) the
            module-level ``slowfast_alpha`` is read at call time, which
            matches the behavior before this parameter existed.
    """

    def __init__(self, alpha=None):
        super().__init__()
        self.alpha = alpha

    def forward(self, frames: torch.Tensor):
        """
        Split ``frames`` into the slow and fast pathway tensors.

        Args:
            frames: Tensor whose dimension 1 indexes the frames.

        Returns:
            A two-element list ``[slow_pathway, fast_pathway]``.
        """
        alpha = slowfast_alpha if self.alpha is None else self.alpha
        total_frames = frames.shape[1]
        # Build the index on the same device as the input; index_select
        # rejects an index tensor that lives on a different device.
        slow_indices = torch.linspace(
            0, total_frames - 1, total_frames // alpha, device=frames.device
        ).long()
        slow_pathway = torch.index_select(frames, 1, slow_indices)
        return [slow_pathway, frames]
# Preprocessing applied to the "video" entry of a clip dictionary. The steps
# run in the order listed and finish by packing the tensor into the
# [slow, fast] pathway list the model consumes.
_video_steps = [
    UniformTemporalSubsample(num_frames),
    Lambda(lambda pixels: pixels / 255.0),
    NormalizeVideo(mean, std),
    ShortSideScale(size=side_size),
    CenterCropVideo(crop_size),
    PackPathway(),
]
transform = ApplyTransformToKey(key="video", transform=Compose(_video_steps))
# Length, in seconds, of the clip that is loaded from each input video.
clip_duration = (num_frames * sampling_rate) / frames_per_second

# Download example video (for local testing and for Gradio examples)
url_link = "https://dl.fbaipublicfiles.com/pytorchvideo/projects/archery.mp4"
video_path = "archery.mp4"
# `urllib.URLopener` is a Python 2 API; on Python 3 it always raised and a
# bare `except` hid the failure. Use `urllib.request` directly instead.
urllib.request.urlretrieve(url_link, video_path)
def inference(in_vid):
    """
    Classify the action in a video and report the top-5 labels as text.

    Args:
        in_vid: Path of the video file to classify, or ``None`` when no
            video was supplied.

    Returns:
        A human-readable string: the five most likely labels on success,
        otherwise a message explaining why no prediction was made. Errors
        are reported in the returned string rather than raised.
    """
    if in_vid is None:
        return "Please upload a video or use the webcam."

    video = None
    try:
        # Initialize an EncodedVideo helper class and load the video
        video = EncodedVideo.from_path(in_vid)

        # Ensure the video is long enough to cover one full clip
        if video.duration < clip_duration:
            return f"Video is too short. Minimum duration is {clip_duration:.2f} seconds."

        # The clip always starts at the beginning of the video
        start_sec = 0
        end_sec = start_sec + clip_duration
        video_data = video.get_clip(start_sec=start_sec, end_sec=end_sec)

        # Apply the preprocessing transform to the loaded clip
        video_data = transform(video_data)

        # Move each pathway tensor to the device and add a leading batch dim
        inputs = [pathway.to(device)[None, ...] for pathway in video_data["video"]]

        # No gradients are needed for inference
        with torch.no_grad():
            preds = model(inputs)

        # Convert scores to probabilities and keep the five best classes
        preds = torch.nn.Softmax(dim=1)(preds)
        pred_classes = preds.topk(k=5).indices[0]

        # Map the predicted class ids to their label names
        labels = ", ".join(kinetics_id_to_classname[int(i)] for i in pred_classes)
        return f"Top 5 predicted labels: {labels}"
    except Exception as e:
        # UI boundary: report decoding/transform/model failures as text
        return f"An error occurred during inference: {e}"
    finally:
        # Release the decoder container; previously it was never closed,
        # leaking one open video per request.
        if video is not None:
            video.close()
# --- UPDATED GRADIO INTERFACE SYNTAX ---
# Removed gr.inputs and gr.outputs
# Input: a video that is either uploaded or recorded from the webcam.
inputs_gradio = gr.Video(
    label="Upload Video or Use Webcam",
    sources=["upload", "webcam"],
    format="mp4",
)
# Output: a single text box holding the prediction string.
outputs_gradio = gr.Textbox(label="Top 5 Predicted Labels")

# Page text shown around the demo.
title = "PyTorchVideo SlowFast Action Recognition"
description = """
Demo for PyTorchVideo's SlowFast model, pretrained on the Kinetics 400 dataset for action recognition.
Upload your video or use your webcam to classify the action.
"""
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/1812.03982' target='_blank'>SlowFast Networks for Video Recognition</a>"
    " | "
    "<a href='https://github.com/facebookresearch/pytorchvideo' target='_blank'>PyTorchVideo GitHub Repo</a>"
    "</p>"
)

# Use the downloaded archery.mp4 as an example
examples = [[video_path]]

# Build the interface and start serving it.
demo = gr.Interface(
    fn=inference,
    inputs=inputs_gradio,
    outputs=outputs_gradio,
    title=title,
    description=description,
    article=article,
    examples=examples,
    analytics_enabled=False,
)
demo.launch()