owinymarvin committed
Commit 8f7ac8f · 1 Parent(s): 482f1b7

first commit

Files changed (2)
  1. app.py +98 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,98 @@
+ import gradio as gr
+ import torch
+ from transformers import AutoImageProcessor, TimesformerForVideoClassification
+ from PIL import Image
+ import time
+ from collections import deque
+
+ # --- Configuration ---
+ # Hugging Face model repository ID
+ HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"
+
+ # These must match the values used during training
+ NUM_FRAMES = 16
+ TARGET_IMAGE_HEIGHT = 224
+ TARGET_IMAGE_WIDTH = 224
+
+ # --- Load Model and Processor ---
+ print(f"Loading model and image processor from {HF_MODEL_REPO_ID}...")
+ try:
+     processor = AutoImageProcessor.from_pretrained(HF_MODEL_REPO_ID)
+     model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID)
+ except Exception as e:
+     print(f"Error loading model from Hugging Face Hub: {e}")
+     raise  # Re-raise so the Space fails visibly instead of exiting silently
+
+ model.eval()  # Set model to evaluation mode
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.to(device)
+ print(f"Model loaded successfully on {device}.")
+ print(f"Model's class labels: {model.config.id2label}")
+
+ # Global frame buffer. Module-level state is shared by every client of the
+ # demo; the deque automatically drops the oldest frame once it is full.
+ frame_buffer = deque(maxlen=NUM_FRAMES)
+ last_inference_time = time.time()
+ inference_interval = 1.0  # Run inference at most once per second
+ current_prediction_text = "Buffering frames..."
+
+
+ def predict_video_frame(image_np_array):
+     global last_inference_time, current_prediction_text
+
+     # Gradio streams frames as RGB numpy arrays
+     pil_image = Image.fromarray(image_np_array)
+     frame_buffer.append(pil_image)
+
+     current_time = time.time()
+
+     # Only run inference once the buffer is full and the interval has elapsed
+     if len(frame_buffer) == NUM_FRAMES and (current_time - last_inference_time) >= inference_interval:
+         last_inference_time = current_time
+
+         # The processor accepts a list of PIL Images (one clip) and resizes
+         # them to the model's expected input resolution
+         processed_input = processor(images=list(frame_buffer), return_tensors="pt")
+         pixel_values = processed_input.pixel_values.to(device)
+
+         with torch.no_grad():
+             outputs = model(pixel_values)
+             logits = outputs.logits
+
+         predicted_class_id = logits.argmax(-1).item()
+         predicted_label = model.config.id2label[predicted_class_id]
+         confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()
+
+         current_prediction_text = f"Predicted: {predicted_label} ({confidence:.2f})"
+         print(current_prediction_text)  # Also appears in the Space logs
+
+     # Return the most recent prediction for display in the UI
+     return current_prediction_text
+
+
+ # --- Gradio Interface ---
+ # Streaming webcam input; the image processor resizes frames, so no fixed
+ # input shape needs to be set here
+ webcam_input = gr.Image(
+     sources=["webcam"],  # Allow webcam input
+     streaming=True,  # Continuously stream frames to the callback
+     label="Live Webcam Feed",
+ )
+
+ # Output text box for predictions
+ prediction_output = gr.Textbox(label="Real-time Prediction")
+
+ # gr.Blocks would give more layout control, but a basic gr.Interface is enough here
+ demo = gr.Interface(
+     fn=predict_video_frame,
+     inputs=webcam_input,
+     outputs=prediction_output,
+     live=True,  # Re-run the function as new frames arrive
+     allow_flagging="never",  # Disable flagging on the public demo
+     title="TimesFormer Crime Detection Live Demo",
+     description=f"This demo uses a fine-tuned TimesFormer model ({HF_MODEL_REPO_ID}) to predict crime actions from a live webcam feed. The model processes {NUM_FRAMES} frames at a time and makes a prediction at most once every {inference_interval} seconds. Please allow webcam access.",
+     # A video-upload input with examples could be added to also support video files
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
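
A quick way to sanity-check the inference path without a webcam is to drive predict_video_frame with synthetic frames. This is a minimal smoke test, assuming the code above is saved as app.py; importing it downloads and loads the model, so the first run takes a while:

import time
import numpy as np
import app  # assumes the code above is saved as app.py; loads the model on import

# Fill the buffer with NUM_FRAMES - 1 random RGB frames (no prediction yet)
for _ in range(app.NUM_FRAMES - 1):
    app.predict_video_frame(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))

# Wait out the throttle interval so the final frame triggers inference
time.sleep(app.inference_interval)
print(app.predict_video_frame(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)))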
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch
+ transformers
+ Pillow
+ gradio