from __future__ import annotations
import os
import openai
import gradio as gr
import server  # side-effect import: expected to boot the local vLLM backend
# ──────────────────────────────────────────────────────────────────────────────
# OpenAI client configuration
# ──────────────────────────────────────────────────────────────────────────────
# ``openai`` still expects an API key even if the backend ignores it, so we use
# a dummy value when none is provided. The *base_url* points to the local
# vLLM server that speaks the OpenAI REST dialect.
# -----------------------------------------------------------------------------
openai_api_key = "EMPTY"
openai_api_base = "http://0.0.0.0:8000/v1"
client = openai.OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)
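
# Optional sanity check (an assumption, not part of the original app): listing
# models confirms the vLLM server is reachable before the UI starts, and the
# try/except keeps the demo booting even if the backend is still coming up.
try:
    client.models.list()
except Exception as exc:
    print(f"Warning: vLLM backend at {openai_api_base} not reachable yet: {exc}")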
# ──────────────────────────────────────────────────────────────────────────────
# Chat handler
# ──────────────────────────────────────────────────────────────────────────────
def stream_completion(message: str,
                      history: list[tuple[str, str]],
                      max_tokens: int,
                      temperature: float,
                      top_p: float,
                      beta: float):
"""Gradio callback that yields streaming assistant replies.
The function reconstructs the conversation *excluding* any system prompt
and then calls ``openai.chat.completions.create`` with ``stream=True``.
Each incoming delta is appended to an ``assistant`` buffer which is sent
back to the Chatbot component for real‑time display.
"""
    # Build an OpenAI-style message list from prior turns
    messages: list[dict[str, str]] = []
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    # Current user input comes last
    messages.append({"role": "user", "content": message})

    # MoI reads its blending strength (beta) from this environment variable
    os.environ["MIXINPUTS_BETA"] = str(beta)
    # Kick off the streaming completion
    response = client.chat.completions.create(
        model="Qwen/Qwen3-4B",
        messages=messages,
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        stream=True,
    )

    # Accumulate deltas into the assistant buffer and push live updates
    assistant = ""
    for chunk in response:
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta.content
        if delta:
            assistant += delta
            yield history + [(message, assistant)]  # live update
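
# Illustrative usage outside Gradio (not executed here): each yielded value is
# the full chat history, so the last tuple of the final yield holds the reply.
#
#     for hist in stream_completion("Hello!", [], max_tokens=64,
#                                   temperature=0.6, top_p=0.8, beta=1.0):
#         ...
#     print(hist[-1][1])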
# ──────────────────────────────────────────────────────────────────────────────
# Gradio UI
# ──────────────────────────────────────────────────────────────────────────────
with gr.Blocks(title="🎨 Mixture of Inputs (MoI) Demo") as demo:
    gr.Markdown(
        "## 🎨 Mixture of Inputs (MoI) Demo with Qwen3-4B\n"
        "Streaming vLLM demo with dynamic **beta** adjustment in MoI; feel how it affects the model!\n"
        "(higher beta → less blending).\n"
        "📕 Paper: https://arxiv.org/abs/2505.14827 \n"
        "💻 Code: https://github.com/EvanZhuang/mixinputs \n"
    )
    with gr.Row():  # sliders first
        beta = gr.Slider(0.0, 10.0, value=1.0, step=0.1, label="MoI β")
        temperature = gr.Slider(0.1, 1.0, value=0.6, step=0.1, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.80, step=0.05, label="Top-p")
        max_tokens = gr.Slider(1, 3072, value=2048, step=1, label="Max new tokens")
    chatbot = gr.Chatbot(height=450)
    user_box = gr.Textbox(placeholder="Type a message and press Enter…", show_label=False)
    clear_btn = gr.Button("Clear chat")
    user_box.submit(
        fn=stream_completion,
        inputs=[user_box, chatbot, max_tokens, temperature, top_p, beta],
        outputs=chatbot,
    )
    clear_btn.click(lambda: None, None, chatbot, queue=False)
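
# Note on streaming: generator callbacks like stream_completion deliver their
# partial yields through Gradio's queue. Recent Gradio releases enable the
# queue by default; on an older version, calling demo.queue() before
# demo.launch() below would be needed (assumption, not in the original app).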
# ──────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
demo.launch()