francismurray committed
Commit d85e33e · 1 Parent(s): 4667b7d

feat: add max tokens control for model responses

Files changed (1)
  1. app.py +27 -8
app.py CHANGED
@@ -24,14 +24,14 @@ AVAILABLE_MODELS = [
 # Initialize inference client
 inference_client = InferenceClient(token=HF_TOKEN)

-async def get_model_response(prompt, model_name, temperature_value, do_sample):
+async def get_model_response(prompt, model_name, temperature_value, do_sample, max_tokens):
     """Get response from a Hugging Face model."""
     try:
         # Build kwargs dynamically
         generation_args = {
             "prompt": prompt,
             "model": model_name,
-            "max_new_tokens": 100,
+            "max_new_tokens": max_tokens,
             "do_sample": do_sample,
             "return_full_text": False
         }
@@ -46,18 +46,23 @@ async def get_model_response(prompt, model_name, temperature_value, do_sample):
             None,
             partial(inference_client.text_generation, **generation_args)
         )
+
+        # Check if response might be truncated
+        if len(response) >= max_tokens * 4:  # Rough estimate of tokens to characters ratio
+            response += "\n\n[Warning: Response may have been truncated. Try increasing the max tokens if the response seems incomplete.]"
+
         return response

     except Exception as e:
         return f"Error: {str(e)}"

-async def process_single_response(prompt, model_name, temp, do_sample, chatbot):
+async def process_single_response(prompt, model_name, temp, do_sample, max_tokens, chatbot):
     """Process a single model response and update its chatbot."""
-    response = await get_model_response(prompt, model_name, temp, do_sample)
+    response = await get_model_response(prompt, model_name, temp, do_sample, max_tokens)
     chat_history = [{"role": "user", "content": prompt}, {"role": "assistant", "content": response}]
     return chat_history

-async def compare_models(prompt, model1, model2, temp1, temp2, do_sample1, do_sample2):
+async def compare_models(prompt, model1, model2, temp1, temp2, do_sample1, do_sample2, max_tokens1, max_tokens2):
     """Compare outputs from two selected models."""
     if not prompt.strip():
         empty_response = [{"role": "user", "content": prompt}, {"role": "assistant", "content": "Please enter a prompt"}]
@@ -69,8 +74,8 @@ async def compare_models(prompt, model1, model2, temp1, temp2, do_sample1, do_sa
     yield initial_message, initial_message, gr.update(interactive=False)

     # Create tasks for both model responses
-    task1 = asyncio.create_task(process_single_response(prompt, model1, temp1, do_sample1, "chatbot1"))
-    task2 = asyncio.create_task(process_single_response(prompt, model2, temp2, do_sample2, "chatbot2"))
+    task1 = asyncio.create_task(process_single_response(prompt, model1, temp1, do_sample1, max_tokens1, "chatbot1"))
+    task2 = asyncio.create_task(process_single_response(prompt, model2, temp2, do_sample2, max_tokens2, "chatbot2"))

     chat1 = chat2 = initial_message
     start_time = asyncio.get_event_loop().time()
@@ -162,6 +167,13 @@ with gr.Blocks(css="""
             interactive=False,
             elem_classes=["disabled-slider"]
         )
+        max_tokens1 = gr.Slider(
+            label="Maximum new tokens in response",
+            minimum=10,
+            maximum=2000,
+            step=10,
+            value=100
+        )
         chatbot1 = gr.Chatbot(
             label="Model 1 Output",
             show_label=True,
@@ -188,6 +200,13 @@ with gr.Blocks(css="""
             interactive=False,
             elem_classes=["disabled-slider"]
         )
+        max_tokens2 = gr.Slider(
+            label="Maximum new tokens in response",
+            minimum=10,
+            maximum=2000,
+            step=10,
+            value=100
+        )
         chatbot2 = gr.Chatbot(
             label="Model 2 Output",
             show_label=True,
@@ -206,7 +225,7 @@ with gr.Blocks(css="""
         queue=False
     ).then(
         fn=compare_models,
-        inputs=[prompt, model1_dropdown, model2_dropdown, temp1, temp2, do_sample1, do_sample2],
+        inputs=[prompt, model1_dropdown, model2_dropdown, temp1, temp2, do_sample1, do_sample2, max_tokens1, max_tokens2],
         outputs=[chatbot1, chatbot2, submit_btn],
         queue=True  # Enable queuing for streaming updates
     )
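Note on the truncation check added in get_model_response: it is a character-count heuristic rather than a real token count, assuming roughly 4 characters per token, so any response of at least max_tokens * 4 characters is flagged as possibly cut off. A minimal standalone sketch of that heuristic is below; the flag_truncation helper name and the chars_per_token parameter are illustrative only and not part of the commit.

def flag_truncation(response: str, max_tokens: int, chars_per_token: int = 4) -> str:
    """Append a warning when a generation likely hit the max_new_tokens cap.

    Mirrors the commit's rough assumption of ~4 characters per token:
    a response of max_tokens * chars_per_token characters or more
    probably exhausted its token budget.
    """
    if len(response) >= max_tokens * chars_per_token:
        response += (
            "\n\n[Warning: Response may have been truncated. "
            "Try increasing the max tokens if the response seems incomplete.]"
        )
    return response

# Example: with a 100-token budget, a ~600-character response gets flagged.
print(flag_truncation("x" * 600, max_tokens=100))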