owinymarvin committed
Commit 4ad907b · 1 Parent(s): e1f8629

latest changes

Files changed (1):
  1. app.py +38 -42
app.py CHANGED
@@ -15,17 +15,15 @@ class SmallVideoClassifier(torch.nn.Module):
     def __init__(self, num_classes=2, num_frames=8):
         super(SmallVideoClassifier, self).__init__()
         from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
-        # Load weights only if they are available (they should be for IMAGENET1K_V1)
-        # Use a check to prevent error if weights are not found in specific environments
         try:
             weights = MobileNet_V3_Small_Weights.IMAGENET1K_V1
         except Exception:
             print("Warning: MobileNet_V3_Small_Weights.IMAGENET1K_V1 not found, initializing without pre-trained weights.")
-            weights = None # Or provide specific default for your use case
+            weights = None
 
         self.feature_extractor = mobilenet_v3_small(weights=weights)
         self.feature_extractor.classifier = torch.nn.Identity()
-        self.num_spatial_features = 576 # MobileNetV3-Small's final feature map size
+        self.num_spatial_features = 576
         self.temporal_aggregator = torch.nn.AdaptiveAvgPool1d(1)
         self.classifier = torch.nn.Sequential(
             torch.nn.Linear(self.num_spatial_features, 512),
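Note (not part of this commit): the forward pass is outside this hunk, but the layers above imply a shape flow of (batch, num_frames, C, H, W) → per-frame 576-d MobileNetV3-Small features → temporal average pool → classifier. A minimal sketch of a forward method consistent with those layers follows; the function name and exact reshaping are illustrative assumptions, and the real forward() in app.py may differ.

import torch

def forward_sketch(model, clip):
    # clip: (batch, num_frames, 3, H, W), e.g. (1, 8, 3, 224, 224)
    b, t, c, h, w = clip.shape
    feats = model.feature_extractor(clip.view(b * t, c, h, w))   # (b*t, 576)
    feats = feats.view(b, t, -1).permute(0, 2, 1)                # (b, 576, t)
    pooled = model.temporal_aggregator(feats).squeeze(-1)        # (b, 576)
    return model.classifier(pooled)                              # (b, num_classes)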
@@ -48,7 +46,6 @@ HF_USERNAME = "owinymarvin"
 NEW_MODEL_REPO_ID_SHORT = "timesformer-violence-detector"
 NEW_MODEL_REPO_ID = f"{HF_USERNAME}/{NEW_MODEL_REPO_ID_SHORT}"
 
-# Download config.json to get model parameters
 print(f"Downloading config.json from {NEW_MODEL_REPO_ID}...")
 config_path = hf_hub_download(repo_id=NEW_MODEL_REPO_ID, filename="config.json")
 with open(config_path, 'r') as f:
@@ -58,24 +55,20 @@ NUM_FRAMES = model_config.get('num_frames', 8)
 IMAGE_SIZE = tuple(model_config.get('image_size', [224, 224]))
 NUM_CLASSES = model_config.get('num_classes', 2)
 
-# Define class labels (adjust if your dataset had different labels/order)
 CLASS_LABELS = ["Non-violence", "Violence"]
 if NUM_CLASSES != len(CLASS_LABELS):
     print(f"Warning: NUM_CLASSES in config ({NUM_CLASSES}) does not match hardcoded CLASS_LABELS length ({len(CLASS_LABELS)}). Adjust CLASS_LABELS if needed.")
 
-
-# Initialize the model
-device = torch.device("cpu") # Explicitly use CPU as requested
+device = torch.device("cpu")
 print(f"Using device: {device}")
 
 model = SmallVideoClassifier(num_classes=NUM_CLASSES, num_frames=NUM_FRAMES)
 
-# Download model weights
 print(f"Downloading model weights from {NEW_MODEL_REPO_ID}...")
 model_weights_path = hf_hub_download(repo_id=NEW_MODEL_REPO_ID, filename="small_violence_classifier.pth")
 model.load_state_dict(torch.load(model_weights_path, map_location=device))
 model.to(device)
-model.eval() # Set model to evaluation mode
+model.eval()
 
 print(f"Model loaded successfully with {NUM_FRAMES} frames and image size {IMAGE_SIZE}.")
 
@@ -89,7 +82,7 @@ transform = transforms.Compose([
 # --- 4. Gradio Inference Function ---
 def predict_video(video_path):
     if video_path is None:
-        return None # Or raise an error, or return a placeholder video
+        return None
 
     cap = cv2.VideoCapture(video_path)
 
@@ -97,50 +90,47 @@ def predict_video(video_path):
         print(f"Error: Could not open video file {video_path}.")
         raise ValueError(f"Could not open video file {video_path}. Please ensure it's a valid video format.")
 
-    # Get video properties
     frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
     frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
     fps = cap.get(cv2.CAP_PROP_FPS)
+    # Ensure FPS is not zero to avoid division by zero errors, default to 25 if needed
+    if fps <= 0:
+        fps = 25.0
+        print(f"Warning: Original video FPS was 0 or less, defaulting to {fps}.")
+
     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
 
-    # Create a temporary output video file
-    # Use tempfile to ensure proper cleanup on Hugging Face Spaces
     temp_output_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
     output_video_path = temp_output_file.name
-    temp_output_file.close() # Close the file handle as cv2.VideoWriter needs to open it
+    temp_output_file.close()
 
-    # Define the codec and create VideoWriter object
-    # For MP4, 'mp4v' is generally compatible.
-    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+    # --- CHANGED: Use XVID codec for better browser compatibility ---
+    # This might prevent Gradio's internal re-encoding.
+    fourcc = cv2.VideoWriter_fourcc(*'XVID')
     out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))
 
     print(f"Processing video: {video_path}")
     print(f"Total frames: {total_frames}, FPS: {fps}")
     print(f"Output video will be saved to: {output_video_path}")
 
-    frame_buffer = [] # To store NUM_FRAMES for each prediction batch
-    current_prediction_label = "Processing..." # Initial label
+    frame_buffer = []
+    current_prediction_label = "Processing..."
 
     frame_idx = 0
     while True:
         ret, frame = cap.read()
         if not ret:
-            break # End of video
+            break
 
         frame_idx += 1
 
-        # Convert frame from BGR (OpenCV) to RGB (PIL/PyTorch)
         frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
         pil_image = Image.fromarray(frame_rgb)
 
-        # Apply transformations and add to buffer
-        processed_frame = transform(pil_image) # shape: (C, H, W)
+        processed_frame = transform(pil_image)
         frame_buffer.append(processed_frame)
 
-        # Perform prediction when the buffer is full
        if len(frame_buffer) == NUM_FRAMES:
-            # Stack the buffered frames and add a batch dimension
-            # Resulting shape: (1, NUM_FRAMES, C, H, W)
             input_tensor = torch.stack(frame_buffer, dim=0).unsqueeze(0).to(device)
 
             with torch.no_grad():
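Note (not part of this commit): XVID inside an .mp4 container is not guaranteed to play natively in browsers, which generally expect H.264 for MP4. If the processed video fails to play in the Gradio output component, one fallback is to re-encode with ffmpeg before returning the path. This is a hedged sketch only: reencode_to_h264 is a hypothetical helper, and it assumes an ffmpeg binary is available on the Space.

import shutil
import subprocess
import tempfile

def reencode_to_h264(src_path):
    # Hypothetical helper: re-encode to H.264/yuv420p so browsers can play the MP4 directly.
    if shutil.which("ffmpeg") is None:
        return src_path  # ffmpeg not installed: fall back to the original file
    dst_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    dst_file.close()
    subprocess.run(
        ["ffmpeg", "-y", "-i", src_path,
         "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-an", dst_file.name],
        check=True,
    )
    return dst_file.name

predict_video could then end with return reencode_to_h264(output_video_path) instead of returning the XVID-encoded file directly.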
@@ -149,39 +139,45 @@ def predict_video(video_path):
                 predicted_class_idx = torch.argmax(probabilities, dim=1).item()
                 current_prediction_label = f"Prediction: {CLASS_LABELS[predicted_class_idx]} (Prob: {probabilities[0, predicted_class_idx]:.2f})"
 
-            # Reset buffer for the next non-overlapping window
             frame_buffer = []
-            # Or, if you want sliding window (more continuous output but higher compute):
-            # frame_buffer = frame_buffer[1:] # e.g., slide by 1 frame
+            # If you want a sliding window, you would do something like:
+            # frame_buffer = frame_buffer[int(NUM_FRAMES * 0.5):] # Slide by half the window size
 
         # Draw prediction text on the current frame
-        # The prediction will lag by NUM_FRAMES, as it's based on the previous batch.
-        # We display the last known prediction.
-        cv2.putText(frame, current_prediction_label, (10, 30),
-                    cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2, cv2.LINE_AA)
+        # Ensure text color is clearly visible (e.g., white or bright green)
+        # Add a black outline for better readability
+        text_color = (0, 255, 0) # Green (BGR format for OpenCV)
+        text_outline_color = (0, 0, 0) # Black
+        font_scale = 1.0 # Increased font size
+        font_thickness = 2
+
+        # Draw outline first for better readability
+        cv2.putText(frame, current_prediction_label, (10, 40), # Slightly lower position
+                    cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_outline_color, font_thickness + 2, cv2.LINE_AA)
+        # Draw actual text
+        cv2.putText(frame, current_prediction_label, (10, 40),
+                    cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_color, font_thickness, cv2.LINE_AA)
 
-        # Write the processed frame to the output video
         out.write(frame)
 
-    # Release resources
     cap.release()
     out.release()
     print(f"Video processing complete. Output saved to: {output_video_path}")
 
-    return output_video_path # Gradio expects the path to the output video
+    return output_video_path
 
 # --- 5. Gradio Interface Setup ---
 iface = gr.Interface(
     fn=predict_video,
-    # Corrected: Removed 'type="filepath"' as it's not a valid argument for gr.Video
     inputs=gr.Video(label="Upload Video for Violence Detection (MP4 recommended)"),
     outputs=gr.Video(label="Processed Video with Predictions"),
     title="Real-time Violence Detection with SmallVideoClassifier",
     description="Upload a video, and the model will analyze it for violence, displaying the predicted class and confidence on each frame.",
-    allow_flagging="never", # Disable flagging on Hugging Face Spaces
+    allow_flagging="never",
     examples=[
-        # You can provide example video URLs or paths if you have them publicly available
-        # e.g., "https://huggingface.co/datasets/gradio/test-files/resolve/main/video.mp4"
+        # Add example videos here for easier testing and demonstration
+        # E.g., a sample video that's publicly accessible:
+        # "https://huggingface.co/datasets/gradio/test-files/resolve/main/video.mp4"
     ]
 )
 
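Note (not part of this commit): with the non-overlapping buffer above, the on-screen label only changes once every NUM_FRAMES frames. The commented-out sliding-window idea can be made concrete with a fixed-size deque; the helper name maybe_predict and the STRIDE value are illustrative assumptions, not code from this repository.

from collections import deque

import torch

NUM_FRAMES = 8
STRIDE = NUM_FRAMES // 2              # illustrative: run the model every 4 frames
frame_buffer = deque(maxlen=NUM_FRAMES)

def maybe_predict(model, processed_frame, frame_idx, device):
    # Keep only the last NUM_FRAMES transformed frames; older ones drop out of
    # the deque automatically, giving an overlapping (sliding) window.
    frame_buffer.append(processed_frame)
    if len(frame_buffer) == NUM_FRAMES and frame_idx % STRIDE == 0:
        clip = torch.stack(list(frame_buffer), dim=0).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(clip)
        return torch.softmax(logits, dim=1)  # (1, num_classes) probabilities
    return None

A smaller stride gives smoother labels at proportionally higher compute, which matters on a CPU-only Space.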