Spaces:
Running
Running
Commit
·
c450b97
1
Parent(s):
4873e8b
latest changes
Browse files
app.py
CHANGED
@@ -41,7 +41,8 @@ def sample_frames(frames_list, target_count):
|
|
41 |
if len(frames_list) <= target_count:
|
42 |
return frames_list
|
43 |
indices = np.linspace(0, len(frames_list) - 1, target_count, dtype=int)
|
44 |
-
|
|
|
45 |
return sampled
|
46 |
|
47 |
def live_predict_stream(image_np_array):
|
@@ -74,7 +75,8 @@ def live_predict_stream(image_np_array):
|
|
74 |
if len(frames_for_model) < MODEL_INPUT_NUM_FRAMES:
|
75 |
prediction_result = "Error: Not enough frames for model."
|
76 |
status_message = "Error during frame sampling."
|
77 |
-
|
|
|
78 |
raw_frames_buffer.clear()
|
79 |
current_clip_start_time = time.time()
|
80 |
last_prediction_completion_time = time.time()
|
@@ -88,7 +90,7 @@ def live_predict_stream(image_np_array):
|
|
88 |
logits = outputs.logits
|
89 |
|
90 |
predicted_class_id = logits.argmax(-1).item()
|
91 |
-
predicted_label = model.config.id2label.get(predicted_class_id, "Unknown")
|
92 |
confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()
|
93 |
|
94 |
prediction_result = f"Predicted: {predicted_label} (Confidence: {confidence:.2f})"
|
@@ -175,12 +177,20 @@ with gr.Blocks() as demo:
|
|
175 |
Use this API endpoint to send base64-encoded frames for prediction.
|
176 |
"""
|
177 |
)
|
|
|
|
|
178 |
gr.Interface(
|
179 |
-
fn=lambda
|
180 |
-
inputs=gr.
|
181 |
-
outputs=gr.Textbox(label="
|
182 |
-
|
|
|
183 |
)
|
|
|
|
|
|
|
|
|
|
|
184 |
|
185 |
if __name__ == "__main__":
|
186 |
demo.launch()
|
|
|
41 |
if len(frames_list) <= target_count:
|
42 |
return frames_list
|
43 |
indices = np.linspace(0, len(frames_list) - 1, target_count, dtype=int)
|
44 |
+
# FIX: Corrected list indexing from () to []
|
45 |
+
sampled = [frames_list[int(i)] for i in indices]
|
46 |
return sampled
|
47 |
|
48 |
def live_predict_stream(image_np_array):
|
|
|
75 |
if len(frames_for_model) < MODEL_INPUT_NUM_FRAMES:
|
76 |
prediction_result = "Error: Not enough frames for model."
|
77 |
status_message = "Error during frame sampling."
|
78 |
+
print(f"ERROR: Insufficient frames for model input: {len(frames_for_model)}/{MODEL_INPUT_NUM_FRAMES}")
|
79 |
+
app_state = "recording" # Reset to recording state
|
80 |
raw_frames_buffer.clear()
|
81 |
current_clip_start_time = time.time()
|
82 |
last_prediction_completion_time = time.time()
|
|
|
90 |
logits = outputs.logits
|
91 |
|
92 |
predicted_class_id = logits.argmax(-1).item()
|
93 |
+
predicted_label = model.config.id2label.get(predicted_class_id, "Unknown")
|
94 |
confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()
|
95 |
|
96 |
prediction_result = f"Predicted: {predicted_label} (Confidence: {confidence:.2f})"
|
|
|
177 |
Use this API endpoint to send base64-encoded frames for prediction.
|
178 |
"""
|
179 |
)
|
180 |
+
# Re-adding a slightly more representative API interface
|
181 |
+
# Gradio's automatic API documentation will use this to show inputs/outputs
|
182 |
gr.Interface(
|
183 |
+
fn=lambda frames_list: f"Received {len(frames_list)} frames. This is a dummy response. Integrate predict_from_frames_api here.",
|
184 |
+
inputs=gr.Json(label="List of Base64-encoded image strings"),
|
185 |
+
outputs=gr.Textbox(label="API Response"),
|
186 |
+
live=False,
|
187 |
+
allow_flagging="never" # For API endpoints, flagging is usually not desired
|
188 |
)
|
189 |
+
# Note: The actual `predict_from_frames_api` function is defined above,
|
190 |
+
# but for a clean API tab, we can use a dummy interface here that Gradio will
|
191 |
+
# use to generate the interactive API documentation. The actual API call
|
192 |
+
# from your local script directly targets the /run/predict_from_frames_api endpoint.
|
193 |
+
|
194 |
|
195 |
if __name__ == "__main__":
|
196 |
demo.launch()
|