fffiloni commited on
Commit
c80c42f
·
verified ·
1 Parent(s): 446c7ea

Add timeout to prediction polling to prevent queue deadlock

Browse files

Previously, the Gradio app would poll the Cog prediction status indefinitely, waiting for the status to become "succeeded" or "failed". If the Cog server became unresponsive or a prediction hung (e.g. due to a bad input or internal model error), the polling loop would never exit.

This behavior caused the Gradio queue to back up and eventually stall completely under high load, since all workers could get stuck waiting for unresponsive predictions.

This change introduces a timeout mechanism (60 seconds) in the polling loop. If the prediction doesn't complete within that time, an error is raised and the request is gracefully terminated. This helps avoid queue deadlocks and improves reliability under concurrent usage.

Files changed (1) hide show
  1. app.py +30 -18
app.py CHANGED
@@ -60,46 +60,58 @@ def predict(request: gr.Request, *args, progress=gr.Progress(track_tqdm=True)):
60
  headers = {'Content-Type': 'application/json'}
61
 
62
  payload = {"input": {}}
63
-
64
-
65
  base_url = "http://0.0.0.0:7860"
 
66
  for i, key in enumerate(names):
67
  value = args[i]
68
- if value and (os.path.exists(str(value))):
69
  value = f"{base_url}/gradio_api/file=" + value
70
  if value is not None and value != "":
71
  payload["input"][key] = value
72
 
73
- time.sleep(1.0)
74
  response = requests.post("http://0.0.0.0:5000/predictions", headers=headers, json=payload)
75
-
76
-
77
  if response.status_code == 201:
78
- time.sleep(1.0)
79
  follow_up_url = response.json()["urls"]["get"]
80
- response = requests.get(follow_up_url, headers=headers)
81
- while response.json()["status"] != "succeeded":
82
- if response.json()["status"] == "failed":
83
- raise gr.Error("The submission failed!")
 
 
 
84
  response = requests.get(follow_up_url, headers=headers)
85
-
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  if response.status_code == 200:
87
-
88
  json_response = response.json()
89
- #If the output component is JSON return the entire output response
90
- if(outputs[0].get_config()["name"] == "json"):
91
  return json_response["output"]
 
92
  predict_outputs = parse_outputs(json_response["output"])
93
  processed_outputs = process_outputs(predict_outputs)
94
  print(f"processed_outputs: {processed_outputs}")
95
  return tuple(processed_outputs) if len(processed_outputs) > 1 else processed_outputs[0]
96
  else:
97
- time.sleep(1)
98
- if(response.status_code == 409):
99
- raise gr.Error(f"Sorry, the Cog image is still processing. Try again in a bit.")
100
  raise gr.Error(f"The submission failed! Error: {response.status_code}")
101
 
102
 
 
103
  css = '''
104
  #col-container{max-width: 800px;margin: 0 auto;}
105
  '''
 
60
  headers = {'Content-Type': 'application/json'}
61
 
62
  payload = {"input": {}}
 
 
63
  base_url = "http://0.0.0.0:7860"
64
+
65
  for i, key in enumerate(names):
66
  value = args[i]
67
+ if value and os.path.exists(str(value)):
68
  value = f"{base_url}/gradio_api/file=" + value
69
  if value is not None and value != "":
70
  payload["input"][key] = value
71
 
72
+ time.sleep(1.0)
73
  response = requests.post("http://0.0.0.0:5000/predictions", headers=headers, json=payload)
74
+
 
75
  if response.status_code == 201:
 
76
  follow_up_url = response.json()["urls"]["get"]
77
+
78
+ # Timeout logic
79
+ max_wait_seconds = 60
80
+ poll_interval = 1
81
+ start_time = time.time()
82
+
83
+ while True:
84
  response = requests.get(follow_up_url, headers=headers)
85
+ try:
86
+ response_json = response.json()
87
+ except ValueError:
88
+ raise gr.Error("Cog server response is not valid JSON.")
89
+
90
+ status = response_json.get("status")
91
+ if status == "succeeded":
92
+ break
93
+ if status == "failed":
94
+ raise gr.Error("The submission failed.")
95
+ if time.time() - start_time > max_wait_seconds:
96
+ raise gr.Error("Prediction timed out after 60 seconds.")
97
+ time.sleep(poll_interval)
98
+
99
  if response.status_code == 200:
 
100
  json_response = response.json()
101
+ if outputs[0].get_config()["name"] == "json":
 
102
  return json_response["output"]
103
+
104
  predict_outputs = parse_outputs(json_response["output"])
105
  processed_outputs = process_outputs(predict_outputs)
106
  print(f"processed_outputs: {processed_outputs}")
107
  return tuple(processed_outputs) if len(processed_outputs) > 1 else processed_outputs[0]
108
  else:
109
+ if response.status_code == 409:
110
+ raise gr.Error("Sorry, the Cog image is still processing. Try again in a bit.")
 
111
  raise gr.Error(f"The submission failed! Error: {response.status_code}")
112
 
113
 
114
+
115
  css = '''
116
  #col-container{max-width: 800px;margin: 0 auto;}
117
  '''