owinymarvin committed
Commit c450b97 · Parent: 4873e8b

latest changes

Files changed (1): app.py +17 -7
app.py CHANGED
@@ -41,7 +41,8 @@ def sample_frames(frames_list, target_count):
     if len(frames_list) <= target_count:
         return frames_list
     indices = np.linspace(0, len(frames_list) - 1, target_count, dtype=int)
-    sampled = [frames_list(int(i)) for i in indices]  # Corrected sampling
+    # FIX: Corrected list indexing from () to []
+    sampled = [frames_list[int(i)] for i in indices]
     return sampled
 
 def live_predict_stream(image_np_array):
@@ -74,7 +75,8 @@ def live_predict_stream(image_np_array):
     if len(frames_for_model) < MODEL_INPUT_NUM_FRAMES:
         prediction_result = "Error: Not enough frames for model."
         status_message = "Error during frame sampling."
-        app_state = "recording"
+        print(f"ERROR: Insufficient frames for model input: {len(frames_for_model)}/{MODEL_INPUT_NUM_FRAMES}")
+        app_state = "recording"  # Reset to recording state
         raw_frames_buffer.clear()
         current_clip_start_time = time.time()
         last_prediction_completion_time = time.time()
@@ -88,7 +90,7 @@ def live_predict_stream(image_np_array):
     logits = outputs.logits
 
     predicted_class_id = logits.argmax(-1).item()
-    predicted_label = model.config.id2label.get(predicted_class_id, "Unknown")  # Handle potential missing label
+    predicted_label = model.config.id2label.get(predicted_class_id, "Unknown")
     confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()
 
     prediction_result = f"Predicted: {predicted_label} (Confidence: {confidence:.2f})"
@@ -175,12 +177,20 @@ with gr.Blocks() as demo:
         Use this API endpoint to send base64-encoded frames for prediction.
         """
     )
+    # Re-adding a slightly more representative API interface
+    # Gradio's automatic API documentation will use this to show inputs/outputs
     gr.Interface(
-        fn=lambda x: "API endpoint is active",
-        inputs=gr.Textbox(label="Input (Base64 JSON)"),
-        outputs=gr.Textbox(label="Status"),
-        title="API Status (Details in app.py)"  # Minimal UI for API tab
+        fn=lambda frames_list: f"Received {len(frames_list)} frames. This is a dummy response. Integrate predict_from_frames_api here.",
+        inputs=gr.Json(label="List of Base64-encoded image strings"),
+        outputs=gr.Textbox(label="API Response"),
+        live=False,
+        allow_flagging="never"  # For API endpoints, flagging is usually not desired
     )
+    # Note: The actual `predict_from_frames_api` function is defined above,
+    # but for a clean API tab, we can use a dummy interface here that Gradio will
+    # use to generate the interactive API documentation. The actual API call
+    # from your local script directly targets the /run/predict_from_frames_api endpoint.
+
 
 if __name__ == "__main__":
     demo.launch()
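Note on the first hunk: the old line called the list like a function, `frames_list(int(i))`, which raises `TypeError: 'list' object is not callable`; the fix indexes it with `[]`. A minimal standalone sketch of the same uniform-sampling technique (the frame values below are illustrative, not the app's real buffer):

import numpy as np

def sample_frames(frames_list, target_count):
    # Short clips pass through unchanged.
    if len(frames_list) <= target_count:
        return frames_list
    # Pick target_count indices spread evenly across the clip.
    indices = np.linspace(0, len(frames_list) - 1, target_count, dtype=int)
    return [frames_list[int(i)] for i in indices]

# Example: 10 frames down-sampled to 4 keeps the first, the last, and two evenly spaced frames.
print(sample_frames(list(range(10)), 4))  # [0, 3, 6, 9]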
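Note on the third hunk: the prediction path is the standard Transformers classification pattern, with `.get(..., "Unknown")` guarding against a class id missing from `id2label`. A self-contained sketch with a made-up three-class head (the labels and logits are assumptions for illustration):

import torch

logits = torch.tensor([[0.2, 2.5, -1.0]])  # hypothetical model output
id2label = {0: "walking", 1: "running", 2: "jumping"}  # stands in for model.config.id2label

predicted_class_id = logits.argmax(-1).item()
predicted_label = id2label.get(predicted_class_id, "Unknown")
confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()
print(f"Predicted: {predicted_label} (Confidence: {confidence:.2f})")  # Predicted: running (Confidence: 0.88)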
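Note on the last hunk: as the added comments say, the dummy `gr.Interface` only feeds Gradio's API docs, and a local script is expected to hit the `/run/predict_from_frames_api` endpoint directly. A rough client sketch, assuming the endpoint accepts a JSON list of base64-encoded JPEG frames in Gradio's usual `{"data": [...]}` envelope (the Space URL, frame count, and payload shape are assumptions, not confirmed by this commit):

import base64
import requests

API_URL = "https://your-space.hf.space/run/predict_from_frames_api"  # hypothetical deployment URL

def encode_frame(path):
    # Base64-encode one JPEG frame read from disk.
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

frames = [encode_frame(f"frame_{i:03d}.jpg") for i in range(16)]  # 16 = assumed MODEL_INPUT_NUM_FRAMES
resp = requests.post(API_URL, json={"data": [frames]})
print(resp.json())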