# Author: owinymarvin
# Commit: aedc519 ("latest changes")
import json
import logging
import urllib
import urllib.request

import gradio as gr
import numpy as np  # Explicitly add numpy import
import torch
from pytorchvideo.data.encoded_video import EncodedVideo
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    ShortSideScale,
    UniformCropVideo,
    UniformTemporalSubsample,
)
from torchvision.transforms import Compose, Lambda
from torchvision.transforms._transforms_video import (
    CenterCropVideo,
    NormalizeVideo,
)
# Inference runs on the CPU: this deployment has no GPU available.
device = "cpu"

# Pretrained SlowFast R50 from the PyTorchVideo hub entry, put into
# evaluation mode and placed on the selected device.
model = torch.hub.load("facebookresearch/pytorchvideo", "slowfast_r50", pretrained=True)
model = model.eval().to(device)
# --- Class Name Loading (from notebook) ---
# Fetch the Kinetics class-name table and build an id -> name lookup used to
# turn the model's output indices into readable labels.
json_url = "https://dl.fbaipublicfiles.com/pyslowfast/dataset/class_names/kinetics_classnames.json"
json_filename = "kinetics_classnames.json"

# `urllib.URLopener` was a Python 2 API and does not exist on Python 3, so the
# old try/bare-except always fell through to this call anyway. Calling it
# directly lets genuine download errors surface instead of being masked.
urllib.request.urlretrieve(json_url, json_filename)

with open(json_filename, "r", encoding="utf-8") as f:
    kinetics_classnames = json.load(f)

# The table is read as {class name: integer id}. Invert it so prediction
# indices map back to names, stripping any embedded double quotes from names.
kinetics_id_to_classname = {
    v: str(k).replace('"', "") for k, v in kinetics_classnames.items()
}
# --- Input-transform hyper-parameters (values taken from the reference notebook) ---

# Spatial preprocessing.
side_size = 256  # length of the shorter side after ShortSideScale
crop_size = 256  # side length of the square CenterCropVideo crop

# Per-channel normalization statistics passed to NormalizeVideo.
mean = [0.45, 0.45, 0.45]
std = [0.225, 0.225, 0.225]

# Temporal preprocessing.
num_frames = 32  # frames kept per clip by UniformTemporalSubsample (fast pathway length)
sampling_rate = 2  # combined with num_frames / frames_per_second to size the clip window
frames_per_second = 30
slowfast_alpha = 4  # divisor for the slow-pathway frame count in PackPathway

# Multi-clip / multi-crop evaluation settings are intentionally omitted:
# inference uses a single clip and a single center crop.
class PackPathway(torch.nn.Module):
    """Split a video tensor into the ``[slow, fast]`` inputs SlowFast expects.

    The fast pathway is the input tensor unchanged. The slow pathway keeps
    ``frames.shape[1] // alpha`` entries of dim 1, chosen at evenly spaced
    (truncated) indices between the first and last position.
    """

    def __init__(self, alpha=None):
        """
        Args:
            alpha: Ratio between fast- and slow-pathway frame counts. When
                omitted, the module-level ``slowfast_alpha`` is used, which
                preserves the original behavior of ``PackPathway()``.

        Raises:
            ValueError: If ``alpha`` is smaller than 1.
        """
        super().__init__()
        self.alpha = slowfast_alpha if alpha is None else alpha
        if self.alpha < 1:
            raise ValueError(f"alpha must be >= 1, got {self.alpha}")

    def forward(self, frames: torch.Tensor):
        """Return ``[slow_pathway, fast_pathway]`` built from ``frames``.

        Dim 1 is treated as the temporal axis (presumably C, T, H, W after
        the preceding video transforms — confirm if the pipeline changes).
        """
        time_steps = frames.shape[1]
        # Build the index tensor on the same device as the frames so this also
        # works if the transform is ever applied to tensors off the CPU.
        slow_index = torch.linspace(
            0, time_steps - 1, time_steps // self.alpha, device=frames.device
        ).long()
        slow_pathway = torch.index_select(frames, 1, slow_index)
        return [slow_pathway, frames]
# Preprocessing applied to the "video" entry of the clip dictionary:
# temporal subsampling, pixel scaling, per-channel normalization, spatial
# resize + center crop, then splitting into [slow, fast] pathway tensors.
transform = ApplyTransformToKey(
    key="video",
    transform=Compose(
        [
            UniformTemporalSubsample(num_frames),  # keep num_frames evenly spaced frames
            Lambda(lambda clip: clip / 255.0),  # assumes decoded pixels lie in [0, 255]
            NormalizeVideo(mean, std),
            ShortSideScale(size=side_size),
            CenterCropVideo(crop_size),
            PackPathway(),
        ]
    ),
)

# Length, in seconds, of the clip window read from each input video.
clip_duration = num_frames * sampling_rate / frames_per_second
# Download example video (for local testing and for Gradio examples).
url_link = "https://dl.fbaipublicfiles.com/pytorchvideo/projects/archery.mp4"
video_path = "archery.mp4"
# `urllib.URLopener` is a Python 2 API; on Python 3 the old bare `except:`
# always fell back to this call, so use it directly and let real download
# failures propagate with their own error message.
urllib.request.urlretrieve(url_link, video_path)
def inference(in_vid):
    """Classify the action in a video with the pretrained SlowFast model.

    Only the first ``clip_duration`` seconds of the video are analysed.

    Args:
        in_vid: Path to the video file provided by the Gradio ``Video``
            component, or ``None``/empty when nothing was supplied.

    Returns:
        A human-readable string: either the top-5 predicted Kinetics labels
        or an explanation of why inference could not run. Errors are reported
        in the returned text rather than raised, so the UI always shows a result.
    """
    if not in_vid:
        return "Please upload a video or use the webcam."

    video = None
    try:
        # The model consumes only frames, so skip audio decoding entirely.
        video = EncodedVideo.from_path(in_vid, decode_audio=False)

        # The clip window is fixed at clip_duration seconds from the start.
        if video.duration < clip_duration:
            return f"Video is too short. Minimum duration is {clip_duration:.2f} seconds."

        start_sec = 0
        end_sec = start_sec + clip_duration
        video_data = video.get_clip(start_sec=start_sec, end_sec=end_sec)
        if not video_data or video_data.get("video") is None:
            # No frames decoded in the window; fail with a clear message
            # instead of a cryptic error from the transform pipeline.
            return "Could not decode any video frames from the uploaded file."

        # Normalize the clip and split it into [slow, fast] pathway tensors.
        video_data = transform(video_data)

        # Add a leading batch dimension to each pathway and move it to the device.
        inputs = [pathway.to(device)[None, ...] for pathway in video_data["video"]]

        with torch.no_grad():  # inference only: no autograd bookkeeping
            preds = model(inputs)

        probs = torch.softmax(preds, dim=1)
        pred_classes = probs.topk(k=5).indices[0]
        pred_class_names = [kinetics_id_to_classname[int(i)] for i in pred_classes]
        return "Top 5 predicted labels: %s" % ", ".join(pred_class_names)
    except Exception as e:
        # UI boundary: show the failure to the user, but keep the full
        # traceback in the server log so it can actually be diagnosed.
        logging.getLogger(__name__).exception("Inference failed for %r", in_vid)
        return f"An error occurred during inference: {e}"
    finally:
        # Release the decoder for this request; without this a long-running
        # server accumulates one open video handle per call.
        if video is not None:
            video.close()
# --- Gradio UI ---
# Components are built directly from `gr` (the legacy gr.inputs / gr.outputs
# namespaces are no longer used).
inputs_gradio = gr.Video(
    label="Upload Video or Use Webcam",
    sources=["upload", "webcam"],
    format="mp4",
)
outputs_gradio = gr.Textbox(label="Top 5 Predicted Labels")

title = "PyTorchVideo SlowFast Action Recognition"
description = """
Demo for PyTorchVideo's SlowFast model, pretrained on the Kinetics 400 dataset for action recognition.
Upload your video or use your webcam to classify the action.
"""
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/1812.03982' target='_blank'>SlowFast Networks for Video Recognition</a>"
    " | "
    "<a href='https://github.com/facebookresearch/pytorchvideo' target='_blank'>PyTorchVideo GitHub Repo</a>"
    "</p>"
)

# The downloaded archery clip doubles as the clickable example input.
examples = [[video_path]]

demo = gr.Interface(
    fn=inference,
    inputs=inputs_gradio,
    outputs=outputs_gradio,
    title=title,
    description=description,
    article=article,
    examples=examples,
    analytics_enabled=False,
)
demo.launch()