owinymarvin committed
Commit 7cc8973 · 1 Parent(s): 8f7ac8f

Add application file

Files changed (1)
  1. app.py +7 -10
app.py CHANGED
@@ -12,7 +12,7 @@ from collections import deque
 HF_MODEL_REPO_ID = "owinymarvin/timesformer-crime-detection"
 
 # These must match the values used during your training
-NUM_FRAMES = 16
+NUM_FRAMES = 8
 TARGET_IMAGE_HEIGHT = 224
 TARGET_IMAGE_WIDTH = 224
 
@@ -37,12 +37,13 @@ print(f"Model's class labels: {model.config.id2label}")
 frame_buffer = deque(maxlen=NUM_FRAMES)
 last_inference_time = time.time()
 inference_interval = 1.0 # Predict every 1 second (1.0 / INFERENCE_FPS)
-current_prediction_text = "Buffering frames..."
+current_prediction_text = "Buffering frames..." # Initialize global text
 
 def predict_video_frame(image_np_array):
     global frame_buffer, last_inference_time, current_prediction_text
 
     # Gradio sends frames as numpy arrays (RGB)
+    # The image_processor will handle the resizing to TARGET_IMAGE_HEIGHT x TARGET_IMAGE_WIDTH
     pil_image = Image.fromarray(image_np_array)
     frame_buffer.append(pil_image)
 
@@ -53,6 +54,7 @@ def predict_video_frame(image_np_array):
         last_inference_time = current_time
 
         # Preprocess the frames. processor expects a list of PIL Images or numpy arrays
+        # It will handle resizing and normalization based on its config
         processed_input = processor(images=list(frame_buffer), return_tensors="pt")
         pixel_values = processed_input.pixel_values.to(device)
 
@@ -68,6 +70,7 @@ def predict_video_frame(image_np_array):
         print(current_prediction_text) # Print to Space logs
 
     # Return the current prediction text for display in the UI
+    # Gradio's streaming will update this textbox asynchronously
     return current_prediction_text
 
 # --- Gradio Interface ---
@@ -75,19 +78,15 @@ def predict_video_frame(image_np_array):
 webcam_input = gr.Image(
     sources=["webcam"], # Allows webcam input
     streaming=True, # Enables continuous streaming of frames
-    shape=(TARGET_IMAGE_WIDTH, TARGET_IMAGE_HEIGHT), # Set expected input resolution
+    # REMOVED: shape=(TARGET_IMAGE_WIDTH, TARGET_IMAGE_HEIGHT), # This was causing the TypeError
     label="Live Webcam Feed"
 )
 
 # Output text box for predictions
-prediction_output = gr.Textbox(label="Real-time Prediction")
+prediction_output = gr.Textbox(label="Real-time Prediction", value="Buffering frames...")
 
 
 # Define the Gradio Interface
-# We use Blocks for more control over layout if needed, but Interface works too.
-# For simplicity, we'll stick to a basic Interface
-# For streaming, gr.Interface.load() is more common, but let's define from scratch.
-
 demo = gr.Interface(
     fn=predict_video_frame,
     inputs=webcam_input,
@@ -96,8 +95,6 @@ demo = gr.Interface(
     allow_flagging="never", # Disable flagging on public demo
     title="TimesFormer Crime Detection Live Demo",
     description=f"This demo uses a finetuned TimesFormer model ({HF_MODEL_REPO_ID}) to predict crime actions from a live webcam feed. The model processes {NUM_FRAMES} frames at a time and makes a prediction every {inference_interval} seconds. Please allow webcam access.",
-    # You might want to add examples for file uploads if you also want to support video files.
-    # examples=["path/to/your/test_video.mp4"] # If you add video upload input
 )
 
 if __name__ == "__main__":
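
For context, below is a minimal, self-contained sketch of the sliding-window buffering and throttled-inference pattern that predict_video_frame relies on. It is an illustration only, not the Space's actual code: the run_model stub and predict_frame name are hypothetical stand-ins, and only NUM_FRAMES, inference_interval, frame_buffer, and current_prediction_text mirror names from the diff.

import time
from collections import deque

NUM_FRAMES = 8              # must match the clip length used during training
inference_interval = 1.0    # run the model at most once per second
frame_buffer = deque(maxlen=NUM_FRAMES)   # oldest frame is dropped automatically
last_inference_time = time.time()
current_prediction_text = "Buffering frames..."


def run_model(frames):
    # Hypothetical stand-in for the preprocessing + TimesFormer forward pass in app.py.
    return f"prediction over {len(frames)} frames"


def predict_frame(frame):
    """Called once per incoming frame; returns the latest prediction text."""
    global last_inference_time, current_prediction_text
    frame_buffer.append(frame)
    now = time.time()
    # Predict only when the buffer holds a full clip and the interval has elapsed,
    # so a ~30 fps webcam stream does not trigger 30 forward passes per second.
    if len(frame_buffer) == NUM_FRAMES and now - last_inference_time >= inference_interval:
        last_inference_time = now
        current_prediction_text = run_model(list(frame_buffer))
    return current_prediction_text


if __name__ == "__main__":
    for i in range(30):
        print(predict_frame(f"frame-{i}"))
        time.sleep(0.1)

In the real app, run_model corresponds to the steps shown in the hunks above: the Hugging Face image processor resizes and normalizes the buffered frames, and the resulting pixel_values are passed to the TimesFormer model.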