llm-compare / app.py
import os
import gradio as gr
import asyncio
from dotenv import load_dotenv
from huggingface_hub import InferenceClient, hf_hub_download, model_info
from functools import partial
# Load environment variables
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("Please set the HF_TOKEN environment variable")
# Available models
AVAILABLE_MODELS = [
"HuggingFaceH4/zephyr-7b-beta",
"NousResearch/Hermes-3-Llama-3.1-8B",
"mistralai/Mistral-Nemo-Base-2407",
"meta-llama/Llama-2-70b-hf",
"aaditya/Llama3-OpenBioLLM-8B",
]
# Initialize inference client
inference_client = InferenceClient(token=HF_TOKEN)
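# A single shared client is reused for every request; the target model is passed per text_generation call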
def get_model_card_html(model_name, title):
    """Fetch model metadata from the Hub and format it as an HTML card.
    The `title` argument is accepted from callers but is not rendered in the card."""
try:
info = model_info(model_name, token=HF_TOKEN)
return f"""
<div class="model-card-container">
<h3>{info.modelId}</h3>
<p><strong>Pipeline Tag:</strong> {info.pipeline_tag or 'Not specified'}</p>
<p><strong>Downloads:</strong> {info.downloads:,}</p>
<p><strong>Likes:</strong> {info.likes:,}</p>
<p><a href="https://huggingface.co/{model_name}" target="_blank">View on Hugging Face</a></p>
</div>
"""
    except Exception:
        # Fall back to a minimal card if the model_info lookup fails
return f"""
<div class="model-card-container">
<h3>{model_name}</h3>
<p>Unable to load full model card information.</p>
<p><a href="https://huggingface.co/{model_name}" target="_blank">View on Hugging Face</a></p>
</div>
"""
async def get_model_response(prompt, model_name, temperature_value, do_sample, max_tokens):
"""Get response from a Hugging Face model."""
try:
# Build kwargs dynamically
generation_args = {
"prompt": prompt,
"model": model_name,
"max_new_tokens": max_tokens,
"do_sample": do_sample,
"return_full_text": False
}
# Only include temperature if sampling is enabled
if do_sample and temperature_value > 0:
generation_args["temperature"] = temperature_value
# Run the inference in a thread pool to not block the event loop
        loop = asyncio.get_running_loop()
response = await loop.run_in_executor(
None,
partial(inference_client.text_generation, **generation_args)
)
# Check if response might be truncated
        if len(response) >= max_tokens * 4:  # Heuristic: ~4 characters per token, so a response this long likely hit the token limit
response += "\n\n[Warning: Response may have been truncated. Try increasing the max tokens if the response seems incomplete.]"
return response
except Exception as e:
return f"Error: {str(e)}"
async def process_single_response(prompt, model_name, temp, do_sample, max_tokens):
    """Get a single model's response and return it as a chat history."""
response = await get_model_response(prompt, model_name, temp, do_sample, max_tokens)
chat_history = [{"role": "user", "content": prompt}, {"role": "assistant", "content": response}]
return chat_history
async def compare_models(prompt, model1, model2, temp1, temp2, do_sample1, do_sample2, max_tokens1, max_tokens2):
"""Compare outputs from two selected models."""
if not prompt.strip():
empty_response = [{"role": "user", "content": prompt}, {"role": "assistant", "content": "Please enter a prompt"}]
yield empty_response, empty_response, gr.update(interactive=True)
return # Exit the generator
# Initialize with "Generating..." messages
initial_message = [{"role": "user", "content": prompt}, {"role": "assistant", "content": "Generating..."}]
yield initial_message, initial_message, gr.update(interactive=False)
# Create tasks for both model responses
    task1 = asyncio.create_task(process_single_response(prompt, model1, temp1, do_sample1, max_tokens1))
    task2 = asyncio.create_task(process_single_response(prompt, model2, temp2, do_sample2, max_tokens2))
chat1 = chat2 = initial_message
    start_time = asyncio.get_running_loop().time()
try:
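        # Poll both tasks, refreshing the "Generating..." placeholders with the elapsed time
        # roughly every 0.1s while the blocking inference calls run in the executor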
while not (task1.done() and task2.done()):
# Update the messages with elapsed time
            elapsed = asyncio.get_running_loop().time() - start_time
if not task1.done():
chat1 = [{"role": "user", "content": prompt},
{"role": "assistant", "content": f"Generating... ({elapsed:.1f}s)"}]
if not task2.done():
chat2 = [{"role": "user", "content": prompt},
{"role": "assistant", "content": f"Generating... ({elapsed:.1f}s)"}]
# Check if any task completed
done, pending = await asyncio.wait([t for t in [task1, task2] if not t.done()],
timeout=0.1,
return_when=asyncio.FIRST_COMPLETED)
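            # asyncio.wait returns as soon as either remaining task finishes or the 0.1s timeout elapses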
for task in done:
                if task is task1:
chat1 = await task1
else:
chat2 = await task2
yield chat1, chat2, gr.update(interactive=False)
# Ensure we have both final results
if not task1.done():
chat1 = await task1
if not task2.done():
chat2 = await task2
# Final yield with both results
yield chat1, chat2, gr.update(interactive=True)
except Exception as e:
error_message = [{"role": "user", "content": prompt}, {"role": "assistant", "content": f"Error: {str(e)}"}]
yield error_message, error_message, gr.update(interactive=True)
# Update temperature slider interactivity based on sampling checkbox
def update_slider_state(enabled):
    """Enable the temperature slider when sampling is on; grey it out and reset it to 0 otherwise."""
    if enabled:
        return gr.update(interactive=True, elem_classes=[])
    return gr.update(interactive=False, elem_classes=["disabled-slider"], value=0)
# Create the Gradio interface
with gr.Blocks(css="""
.disabled-slider { opacity: 0.5; pointer-events: none; }
.model-card-container {
background-color: #f8f9fa;
font-size: 14px;
color: #666;
}
.model-card-container h3 {
margin: 0;
color: black;
}
.model-card-container p {
margin: 5px 0;
}
""") as demo:
gr.Markdown("# LLM Comparison Tool")
    gr.Markdown("Compare outputs from different `text-generation` models side by side using Hugging Face's Inference API.")
with gr.Row():
prompt = gr.Textbox(
label="Enter your prompt",
placeholder="Type your prompt here...",
lines=3
)
with gr.Row():
submit_btn = gr.Button("Generate Responses")
with gr.Row():
with gr.Column():
model1_dropdown = gr.Dropdown(
choices=AVAILABLE_MODELS,
value=AVAILABLE_MODELS[0],
label="Select Model 1"
)
model1_card = gr.HTML(
value=get_model_card_html(AVAILABLE_MODELS[0], "Model 1 Information"),
elem_classes=["model-card-container"]
)
do_sample1 = gr.Checkbox(
label="Enable sampling (random outputs)",
value=False
)
temp1 = gr.Slider(
label="Temperature (Higher = more creative, lower = more predictable)",
minimum=0,
maximum=1,
step=0.1,
value=0.0,
interactive=False,
elem_classes=["disabled-slider"]
)
max_tokens1 = gr.Slider(
label="Maximum new tokens in response",
minimum=10,
maximum=2000,
step=10,
value=10
)
chatbot1 = gr.Chatbot(
label="Model 1 Output",
show_label=True,
height=300,
type="messages"
)
with gr.Column():
model2_dropdown = gr.Dropdown(
choices=AVAILABLE_MODELS,
value=AVAILABLE_MODELS[1],
label="Select Model 2"
)
model2_card = gr.HTML(
value=get_model_card_html(AVAILABLE_MODELS[1], "Model 2 Information"),
elem_classes=["model-card-container"]
)
do_sample2 = gr.Checkbox(
label="Enable sampling (random outputs)",
value=False
)
temp2 = gr.Slider(
label="Temperature (Higher = more creative, lower = more predictable)",
minimum=0,
maximum=1,
step=0.1,
value=0.0,
interactive=False,
elem_classes=["disabled-slider"]
)
max_tokens2 = gr.Slider(
label="Maximum new tokens in response",
minimum=10,
maximum=2000,
step=10,
value=10
)
chatbot2 = gr.Chatbot(
label="Model 2 Output",
show_label=True,
height=300,
type="messages"
)
def start_loading():
return gr.update(interactive=False)
# Handle form submission
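    # Two-step chain: first disable the button immediately (unqueued), then stream the comparison,
    # whose final yield re-enables the button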
submit_btn.click(
fn=start_loading,
inputs=None,
outputs=submit_btn,
queue=False
).then(
fn=compare_models,
inputs=[prompt, model1_dropdown, model2_dropdown, temp1, temp2, do_sample1, do_sample2, max_tokens1, max_tokens2],
outputs=[chatbot1, chatbot2, submit_btn],
queue=True # Enable queuing for streaming updates
)
# Update model cards when models are changed
model1_dropdown.change(
fn=lambda x: get_model_card_html(x, "Model 1 Information"),
inputs=[model1_dropdown],
outputs=[model1_card]
)
model2_dropdown.change(
fn=lambda x: get_model_card_html(x, "Model 2 Information"),
inputs=[model2_dropdown],
outputs=[model2_card]
)
    # Toggle the temperature sliders when the sampling checkboxes change
    do_sample1.change(
        fn=update_slider_state,
        inputs=[do_sample1],
        outputs=[temp1]
    )
    do_sample2.change(
        fn=update_slider_state,
        inputs=[do_sample2],
        outputs=[temp2]
    )
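# queue() lets the generator-based compare_models handler stream its intermediate yields to the UI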
if __name__ == "__main__":
demo.queue().launch()