francismurray committed
Commit 4667b7d · Parent: b6ed54a

feat: Implement async model responses with real-time progress

- Make model calls run concurrently instead of sequentially
- Add real-time progress indicator with 0.1s precision timer
- Display responses immediately as they arrive
- Improve error handling and loading states

Files changed (1)
  1. app.py +71 -23
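
A note on the approach: the commit message above describes two blocking inference calls dispatched at once and polled on a short 0.1 s timeout so the UI can keep showing elapsed time. The same pattern can be sketched in isolation. This is a minimal sketch, independent of the app; fake_inference, call_model, and the delays are hypothetical stand-ins for the InferenceClient.text_generation calls:

import asyncio
import time
from functools import partial


def fake_inference(delay, text):
    # Hypothetical stand-in for a blocking call such as InferenceClient.text_generation
    time.sleep(delay)
    return text


async def call_model(delay, text):
    # Off-load the blocking call to the default thread pool so the event loop stays responsive
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, partial(fake_inference, delay, text))


async def main():
    # Two "model calls" started concurrently, mirroring compare_models
    task1 = asyncio.create_task(call_model(0.3, "model 1 done"))
    task2 = asyncio.create_task(call_model(0.7, "model 2 done"))
    start = time.monotonic()

    # Poll roughly every 0.1 s until both tasks finish, reporting elapsed time
    while not (task1.done() and task2.done()):
        print(f"Generating... ({time.monotonic() - start:.1f}s)")
        await asyncio.wait([t for t in (task1, task2) if not t.done()],
                           timeout=0.1,
                           return_when=asyncio.FIRST_COMPLETED)

    print(await task1, "|", await task2)


if __name__ == "__main__":
    asyncio.run(main())

Off-loading the blocking client call to the default executor is what keeps the event loop free to run the 0.1 s polling loop, so the progress text can keep updating while both calls are still in flight.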
app.py CHANGED
@@ -1,7 +1,9 @@
 import os
 import gradio as gr
+import asyncio
 from dotenv import load_dotenv
 from huggingface_hub import InferenceClient
+from functools import partial
 
 # Load environment variables
 load_dotenv()
@@ -22,7 +24,7 @@ AVAILABLE_MODELS = [
 # Initialize inference client
 inference_client = InferenceClient(token=HF_TOKEN)
 
-def get_model_response(prompt, model_name, temperature_value, do_sample):
+async def get_model_response(prompt, model_name, temperature_value, do_sample):
     """Get response from a Hugging Face model."""
     try:
         # Build kwargs dynamically
@@ -38,31 +40,80 @@ def get_model_response(prompt, model_name, temperature_value, do_sample):
         if do_sample and temperature_value > 0:
             generation_args["temperature"] = temperature_value
 
-        response = inference_client.text_generation(**generation_args)
+        # Run the inference in a thread pool to not block the event loop
+        loop = asyncio.get_event_loop()
+        response = await loop.run_in_executor(
+            None,
+            partial(inference_client.text_generation, **generation_args)
+        )
         return response
 
     except Exception as e:
         return f"Error: {str(e)}"
 
-def compare_models(prompt, model1, model2, temp1, temp2, do_sample1, do_sample2):
+async def process_single_response(prompt, model_name, temp, do_sample, chatbot):
+    """Process a single model response and update its chatbot."""
+    response = await get_model_response(prompt, model_name, temp, do_sample)
+    chat_history = [{"role": "user", "content": prompt}, {"role": "assistant", "content": response}]
+    return chat_history
+
+async def compare_models(prompt, model1, model2, temp1, temp2, do_sample1, do_sample2):
     """Compare outputs from two selected models."""
     if not prompt.strip():
-        return (
-            [{"role": "user", "content": prompt}, {"role": "assistant", "content": "Please enter a prompt"}],
-            [{"role": "user", "content": prompt}, {"role": "assistant", "content": "Please enter a prompt"}],
-            gr.update(interactive=True)
-        )
+        empty_response = [{"role": "user", "content": prompt}, {"role": "assistant", "content": "Please enter a prompt"}]
+        yield empty_response, empty_response, gr.update(interactive=True)
+        return  # Exit the generator
 
-    response1 = get_model_response(prompt, model1, temp1, do_sample1)
-    response2 = get_model_response(prompt, model2, temp2, do_sample2)
+    # Initialize with "Generating..." messages
+    initial_message = [{"role": "user", "content": prompt}, {"role": "assistant", "content": "Generating..."}]
+    yield initial_message, initial_message, gr.update(interactive=False)
 
-    # Format responses for chatbot display
-    chat1 = [{"role": "user", "content": prompt}, {"role": "assistant", "content": response1}]
-    chat2 = [{"role": "user", "content": prompt}, {"role": "assistant", "content": response2}]
-
+    # Create tasks for both model responses
+    task1 = asyncio.create_task(process_single_response(prompt, model1, temp1, do_sample1, "chatbot1"))
+    task2 = asyncio.create_task(process_single_response(prompt, model2, temp2, do_sample2, "chatbot2"))
 
-    return chat1, chat2, gr.update(interactive=True)
-
+    chat1 = chat2 = initial_message
+    start_time = asyncio.get_event_loop().time()
+
+    try:
+        while not (task1.done() and task2.done()):
+            # Update the messages with elapsed time
+            elapsed = round(asyncio.get_event_loop().time() - start_time, 1)
+            chat1_content = chat1[1]["content"]
+            chat2_content = chat2[1]["content"]
+
+            if not task1.done():
+                chat1 = [{"role": "user", "content": prompt},
+                         {"role": "assistant", "content": f"Generating... ({elapsed:.1f}s)"}]
+            if not task2.done():
+                chat2 = [{"role": "user", "content": prompt},
+                         {"role": "assistant", "content": f"Generating... ({elapsed:.1f}s)"}]
+
+            # Check if any task completed
+            done, pending = await asyncio.wait([t for t in [task1, task2] if not t.done()],
+                                               timeout=0.1,
+                                               return_when=asyncio.FIRST_COMPLETED)
+
+            for task in done:
+                if task == task1:
+                    chat1 = await task1
+                else:
+                    chat2 = await task2
+
+            yield chat1, chat2, gr.update(interactive=False)
+
+        # Ensure we have both final results
+        if not task1.done():
+            chat1 = await task1
+        if not task2.done():
+            chat2 = await task2
+
+        # Final yield with both results
+        yield chat1, chat2, gr.update(interactive=True)
+
+    except Exception as e:
+        error_message = [{"role": "user", "content": prompt}, {"role": "assistant", "content": f"Error: {str(e)}"}]
+        yield error_message, error_message, gr.update(interactive=True)
 
 # Update temperature slider interactivity based on sampling checkbox
 def update_slider_state(enabled):
@@ -79,7 +130,7 @@ with gr.Blocks(css="""
     .disabled-slider { opacity: 0.5; pointer-events: none; }
 """) as demo:
     gr.Markdown("# LLM Comparison Tool")
-    gr.Markdown("Compare outputs from different Hugging Face models side by side.")
+    gr.Markdown("Using HuggingFace's Inference API, compare outputs from different `text-generation` models side by side.")
 
     with gr.Row():
         prompt = gr.Textbox(
@@ -117,7 +168,6 @@ with gr.Blocks(css="""
                 height=300,
                 type="messages"
             )
-
 
         with gr.Column():
             model2_dropdown = gr.Dropdown(
@@ -157,11 +207,10 @@ with gr.Blocks(css="""
     ).then(
         fn=compare_models,
         inputs=[prompt, model1_dropdown, model2_dropdown, temp1, temp2, do_sample1, do_sample2],
-        outputs=[chatbot1, chatbot2, submit_btn]
+        outputs=[chatbot1, chatbot2, submit_btn],
+        queue=True  # Enable queuing for streaming updates
     )
 
-
-
     do_sample1.change(
         fn=update_slider_state,
         inputs=[do_sample1],
@@ -175,5 +224,4 @@ with gr.Blocks(css="""
     )
 
 if __name__ == "__main__":
-    demo.launch()
-    # demo.launch(share=True)
+    demo.queue().launch()
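
For context on the outputs/queue=True and demo.queue().launch() changes at the end of the diff: Gradio can stream intermediate results when an event handler is a generator (or async generator) and queuing is enabled, so each yield in compare_models becomes a UI update. Below is a minimal, standalone sketch of that wiring; the count_up handler and the component names are illustrative and not part of this commit:

import asyncio
import gradio as gr


async def count_up(n):
    # Each yield becomes an intermediate update of the output textbox
    for i in range(int(n)):
        await asyncio.sleep(0.5)
        yield f"step {i + 1} of {int(n)}"


with gr.Blocks() as demo:
    steps = gr.Number(value=5, label="Steps")
    out = gr.Textbox(label="Progress")
    btn = gr.Button("Run")
    # queue=True routes this event through the queue so yielded values stream to the client
    btn.click(fn=count_up, inputs=[steps], outputs=[out], queue=True)

if __name__ == "__main__":
    # queue() enables the queue worker that delivers the intermediate updates
    demo.queue().launch()

The same mechanism is what lets compare_models push the "Generating... (Ns)" placeholders and then the final responses into the two chatbots as they arrive.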