owinymarvin committed on
Commit
9d0ee1c
·
1 Parent(s): 7cc8973

latest changes

Browse files
Files changed (2) hide show
  1. app copy.py +101 -0
  2. app.py +84 -53
app copy.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoImageProcessor, TimesformerForVideoClassification
4
+ import cv2
5
+ from PIL import Image
6
+ import numpy as np
7
+ import time
8
+ from collections import deque
9
+
# --- Configuration ---
# Hugging Face Hub repository that hosts the fine-tuned model and its
# image-processor configuration.
HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"

# These must match the values used during training.
NUM_FRAMES = 8
TARGET_IMAGE_HEIGHT = 224
TARGET_IMAGE_WIDTH = 224

# --- Load Model and Processor ---
print(f"Loading model and image processor from {HF_MODEL_REPO_ID}...")
try:
    processor = AutoImageProcessor.from_pretrained(HF_MODEL_REPO_ID)
    model = TimesformerForVideoClassification.from_pretrained(HF_MODEL_REPO_ID)
except Exception as e:
    print(f"Error loading model from Hugging Face Hub: {e}")
    # Abort with a non-zero status so the host sees the start-up failure.
    # The bare `exit()` helper reports status 0 and is only available when
    # the `site` module has been loaded.
    raise SystemExit(1) from e

model.eval()  # Inference only: switch the model to evaluation mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Model loaded successfully on {device}.")
print(f"Model's class labels: {model.config.id2label}")

# Streaming state. These are module-level globals, so they are shared by
# every call to the prediction callback in this process.
# A bounded deque drops the oldest frame automatically once it is full.
frame_buffer = deque(maxlen=NUM_FRAMES)
last_inference_time = time.time()
inference_interval = 1.0  # Minimum number of seconds between two predictions.
current_prediction_text = "Buffering frames..."  # Text shown until the first prediction.
41
+
def predict_video_frame(image_np_array):
    """Buffer one webcam frame and, at most once per interval, classify the clip.

    Args:
        image_np_array: The current webcam frame as delivered by Gradio
            (a numpy array), or ``None`` when no frame is available.

    Returns:
        The most recent prediction text. While the buffer is still filling,
        or between two inference runs, the previously stored text is returned.
    """
    global last_inference_time, current_prediction_text

    # Gradio can invoke the callback without a frame (for example when the
    # webcam component is cleared). Image.fromarray(None) would raise, so
    # keep showing the last known text instead.
    if image_np_array is None:
        return current_prediction_text

    # The image processor takes care of resizing, so the raw frame is
    # stored as-is. frame_buffer is only mutated, never rebound, hence it
    # does not need to appear in the `global` statement.
    frame_buffer.append(Image.fromarray(image_np_array))

    now = time.time()
    buffer_full = len(frame_buffer) == NUM_FRAMES
    interval_elapsed = (now - last_inference_time) >= inference_interval

    # Run the model only when a full clip is available and enough time has
    # passed since the previous prediction.
    if buffer_full and interval_elapsed:
        last_inference_time = now

        # The processor accepts the list of frames and applies the resizing
        # and normalization defined in its configuration.
        processed_input = processor(images=list(frame_buffer), return_tensors="pt")
        pixel_values = processed_input.pixel_values.to(device)

        with torch.no_grad():
            logits = model(pixel_values).logits

        predicted_class_id = logits.argmax(-1).item()
        predicted_label = model.config.id2label[predicted_class_id]
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        confidence = probabilities[0][predicted_class_id].item()

        current_prediction_text = f"Predicted: {predicted_label} ({confidence:.2f})"
        print(current_prediction_text)  # Also visible in the Space logs.

    return current_prediction_text
75
+
# --- Gradio Interface ---
# Webcam component that streams frames to the callback continuously.
# NOTE: no `shape=` argument is passed here; it was removed because it
# caused a TypeError.
webcam_input = gr.Image(sources=["webcam"], streaming=True, label="Live Webcam Feed")

# Text box that displays the latest prediction.
prediction_output = gr.Textbox(
    value="Buffering frames...",
    label="Real-time Prediction",
)

# Wire the streaming webcam input to the prediction callback.
demo = gr.Interface(
    fn=predict_video_frame,
    inputs=webcam_input,
    outputs=prediction_output,
    title="TimesFormer Crime Detection Live Demo",
    description=(
        f"This demo uses a finetuned TimesFormer model ({HF_MODEL_REPO_ID}) "
        "to predict crime actions from a live webcam feed. "
        f"The model processes {NUM_FRAMES} frames at a time and makes a "
        f"prediction every {inference_interval} seconds. "
        "Please allow webcam access."
    ),
    live=True,  # Re-run the callback whenever the input changes.
    allow_flagging="never",  # No flagging on the public demo.
)

if __name__ == "__main__":
    demo.launch()
app.py CHANGED
@@ -12,7 +12,7 @@ from collections import deque
12
  HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"
13
 
14
  # These must match the values used during your training
15
- NUM_FRAMES = 8
16
  TARGET_IMAGE_HEIGHT = 224
17
  TARGET_IMAGE_WIDTH = 224
18
 
@@ -32,70 +32,101 @@ model.to(device)
32
  print(f"Model loaded successfully on {device}.")
33
  print(f"Model's class labels: {model.config.id2label}")
34
 
35
- # Initialize a global buffer for frames for the session
36
- # Use a deque for efficient appending/popping from both ends
37
- frame_buffer = deque(maxlen=NUM_FRAMES)
38
- last_inference_time = time.time()
39
- inference_interval = 1.0 # Predict every 1 second (1.0 / INFERENCE_FPS)
40
- current_prediction_text = "Buffering frames..." # Initialize global text
41
 
42
- def predict_video_frame(image_np_array):
43
- global frame_buffer, last_inference_time, current_prediction_text
44
 
45
- # Gradio sends frames as numpy arrays (RGB)
46
- # The image_processor will handle the resizing to TARGET_IMAGE_HEIGHT x TARGET_IMAGE_WIDTH
 
 
 
47
  pil_image = Image.fromarray(image_np_array)
48
- frame_buffer.append(pil_image)
 
 
 
 
49
 
50
- current_time = time.time()
 
 
51
 
52
- # Only perform inference if we have enough frames and it's time for a prediction
53
- if len(frame_buffer) == NUM_FRAMES and (current_time - last_inference_time) >= inference_interval:
54
- last_inference_time = current_time
55
 
56
- # Preprocess the frames. processor expects a list of PIL Images or numpy arrays
57
- # It will handle resizing and normalization based on its config
58
- processed_input = processor(images=list(frame_buffer), return_tensors="pt")
59
- pixel_values = processed_input.pixel_values.to(device)
60
 
61
- with torch.no_grad():
62
- outputs = model(pixel_values)
63
- logits = outputs.logits
 
64
 
65
- predicted_class_id = logits.argmax(-1).item()
66
- predicted_label = model.config.id2label[predicted_class_id]
67
- confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()
68
 
69
- current_prediction_text = f"Predicted: {predicted_label} ({confidence:.2f})"
70
- print(current_prediction_text) # Print to Space logs
 
 
 
 
 
 
 
 
71
 
72
- # Return the current prediction text for display in the UI
73
- # Gradio's streaming will update this textbox asynchronously
74
- return current_prediction_text
 
 
 
 
 
 
75
 
76
  # --- Gradio Interface ---
77
- # Create a streaming input for webcam
78
- webcam_input = gr.Image(
79
- sources=["webcam"], # Allows webcam input
80
- streaming=True, # Enables continuous streaming of frames
81
- # REMOVED: shape=(TARGET_IMAGE_WIDTH, TARGET_IMAGE_HEIGHT), # This was causing the TypeError
82
- label="Live Webcam Feed"
83
- )
84
-
85
- # Output text box for predictions
86
- prediction_output = gr.Textbox(label="Real-time Prediction", value="Buffering frames...")
87
-
88
-
89
- # Define the Gradio Interface
90
- demo = gr.Interface(
91
- fn=predict_video_frame,
92
- inputs=webcam_input,
93
- outputs=prediction_output,
94
- live=True, # Enable live updates
95
- allow_flagging="never", # Disable flagging on public demo
96
- title="TimesFormer Crime Detection Live Demo",
97
- description=f"This demo uses a finetuned TimesFormer model ({HF_MODEL_REPO_ID}) to predict crime actions from a live webcam feed. The model processes {NUM_FRAMES} frames at a time and makes a prediction every {inference_interval} seconds. Please allow webcam access.",
98
- )
 
 
 
 
 
 
 
 
 
 
 
99
 
100
  if __name__ == "__main__":
101
  demo.launch()
 
12
  HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"
13
 
14
  # These must match the values used during your training
15
+ NUM_FRAMES = 16 # Still expecting 16 frames for a batch
16
  TARGET_IMAGE_HEIGHT = 224
17
  TARGET_IMAGE_WIDTH = 224
18
 
 
32
  print(f"Model loaded successfully on {device}.")
33
  print(f"Model's class labels: {model.config.id2label}")
34
 
# Rolling window that always holds the most recent NUM_FRAMES webcam frames.
# It lives at module level so that its contents persist between Gradio
# callback invocations.
captured_frames_buffer = deque(maxlen=NUM_FRAMES)

# Length, in seconds, of the artificial wait used for testing (5 minutes).
wait_duration_seconds = 5 * 60
42
 
# --- Function to continuously capture frames (without immediate processing) ---
def capture_frame_into_buffer(image_np_array):
    """Store one streamed webcam frame and report how full the buffer is.

    Args:
        image_np_array: The current webcam frame as delivered by Gradio
            (a numpy array), or ``None`` when no frame is available.

    Returns:
        A status string of the form ``"Frames buffered: <count>/<NUM_FRAMES>"``.
    """
    # Gradio can invoke the callback without a frame (for example when the
    # webcam component is cleared); Image.fromarray(None) would raise, so
    # such calls only refresh the status text.
    if image_np_array is not None:
        # The buffer is only mutated, never rebound, so no `global`
        # statement is required.
        captured_frames_buffer.append(Image.fromarray(image_np_array))

    return f"Frames buffered: {len(captured_frames_buffer)}/{NUM_FRAMES}"
53
+
54
 
# --- Function to trigger prediction with the buffered frames ---
def make_prediction_from_buffer():
    """Classify the frames currently held in ``captured_frames_buffer``.

    Returns:
        A hint asking for more frames when fewer than ``NUM_FRAMES`` are
        buffered; otherwise ``"Predicted: <label> (<confidence>)"`` for the
        buffered clip.
    """
    global captured_frames_buffer

    # Guard clause: the model needs a complete clip.
    if len(captured_frames_buffer) < NUM_FRAMES:
        return "Not enough frames buffered yet. Please capture more frames."

    # Snapshot the deque as a list, which is what the processor is given.
    clip = list(captured_frames_buffer)

    # --- Perform Inference ---
    print(f"Triggered inference on {len(clip)} frames...")
    batch = processor(images=clip, return_tensors="pt")
    pixel_values = batch.pixel_values.to(device)

    with torch.no_grad():
        logits = model(pixel_values).logits

    class_idx = logits.argmax(-1).item()
    label = model.config.id2label[class_idx]
    class_probs = torch.nn.functional.softmax(logits, dim=-1)
    score = class_probs[0][class_idx].item()

    prediction_text = f"Predicted: {label} ({score:.2f})"
    print(prediction_text)  # Also visible in the Space logs.

    # NOTE: the buffer is intentionally not cleared, so another click
    # predicts on whatever the latest NUM_FRAMES frames are at that moment.
    # Call captured_frames_buffer.clear() here to require a fresh set of
    # frames for every click.
    #
    # NOTE: an artificial delay (time.sleep(wait_duration_seconds)) can be
    # inserted before the return for testing; it blocks the UI update until
    # the wait has finished.
    return prediction_text
94
+
95
 
# --- Gradio Interface ---
# Layout: the webcam feed, buffer status and trigger button are created in
# the first column, the prediction result in the second.
with gr.Blocks() as demo:
    # The instructions name the button exactly as it is labelled below, so
    # the text and the UI agree.
    gr.Markdown(
        f"""
        # TimesFormer Crime Detection Live Demo (Manual Trigger)
        This demo uses a finetuned TimesFormer model ({HF_MODEL_REPO_ID}) to predict crime actions from a live webcam feed.
        It continuously buffers frames, but **only makes a prediction when you click the 'Predict Latest Frames' button**.
        The model requires **{NUM_FRAMES} frames** for a prediction.
        Please allow webcam access.
        """
    )
    with gr.Row():
        with gr.Column():
            webcam_input = gr.Image(
                sources=["webcam"],
                streaming=True,
                label="Live Webcam Feed",
            )
            # Shows how many frames have been buffered so far.
            buffer_status = gr.Textbox(
                label="Frame Buffer Status",
                value=f"Frames buffered: 0/{NUM_FRAMES}",
            )

            # Button that triggers a prediction on the buffered frames.
            predict_button = gr.Button("Predict Latest Frames")

        with gr.Column():
            prediction_output = gr.Textbox(
                label="Prediction Result",
                value="Click 'Predict Latest Frames' to start.",
            )

    # Event wiring is declared inside the Blocks context.
    # Every streamed frame is buffered and the status text refreshed.
    webcam_input.stream(capture_frame_into_buffer, inputs=[webcam_input], outputs=[buffer_status])

    # A click runs inference on the buffered frames.
    predict_button.click(make_prediction_from_buffer, inputs=[], outputs=[prediction_output])


if __name__ == "__main__":
    demo.launch()