File size: 4,930 Bytes
2d7f359
 
 
 
 
2a57165
2d7f359
 
 
 
 
 
 
 
25fd9cd
e4a7eb4
25fd9cd
e89cacf
25fd9cd
 
 
2d7f359
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ab546e
2d7f359
7ab546e
2d7f359
7ab546e
 
2d7f359
f873ce7
4c07c1e
a28682a
 
4cf93d3
 
25fd9cd
4cf93d3
 
 
 
 
 
 
592c12e
25fd9cd
c1965a3
2d7f359
 
 
 
 
c2ec273
c1965a3
3f54d28
81f72df
 
 
 
c1965a3
 
39209e4
2d7f359
 
 
bfbfc37
4c07c1e
39209e4
 
c1965a3
4c07c1e
c1965a3
39209e4
c1965a3
 
 
2d7f359
c1965a3
a846510
2d7f359
f873ce7
c1965a3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from __future__ import annotations

import os
import openai
import gradio as gr
import server

# ──────────────────────────────────────────────────────────────────────────────
# OpenAI client configuration
# ──────────────────────────────────────────────────────────────────────────────
# ``openai`` still expects an API key even if the backend ignores it, so we use
# a dummy value when none is provided.  The *base_url* points to the local
# vLLM server that speaks the OpenAI REST dialect.
# -----------------------------------------------------------------------------
openai_api_key = "EMPTY"  # placeholder — the local vLLM backend does not check it
openai_api_base = "http://0.0.0.0:8000/v1"  # local vLLM OpenAI-compatible endpoint

# Module-level client reused by every chat request in ``stream_completion``.
client = openai.OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

# ──────────────────────────────────────────────────────────────────────────────
# Chat handler
# ──────────────────────────────────────────────────────────────────────────────

def stream_completion(message: str,
                      history: list[tuple[str, str]],
                      max_tokens: int,
                      temperature: float,
                      top_p: float,
                      beta: float):
    """Gradio callback that yields streaming assistant replies.

    Rebuilds the OpenAI-style message list from prior turns (*excluding*
    any system prompt), forwards the MoI ``beta`` via the environment, and
    calls ``chat.completions.create`` with ``stream=True``.  Each incoming
    delta is appended to an assistant buffer and the updated history is
    yielded so the Chatbot component renders the reply in real time.

    Parameters
    ----------
    message : str
        The current user input.
    history : list[tuple[str, str]]
        Prior ``(user, assistant)`` turns from the Chatbot component.
    max_tokens : int
        Maximum number of new tokens to generate.
    temperature : float
        Sampling temperature.
    top_p : float
        Nucleus-sampling cutoff.
    beta : float
        MoI blending coefficient, exported as ``MIXINPUTS_BETA``.

    Yields
    ------
    list[tuple[str, str]]
        ``history`` extended with the in-progress ``(message, reply)`` turn.
    """
    # Build OpenAI-style message list from prior turns.
    messages: list[dict[str, str]] = []
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    # Current user input comes last.
    messages.append({"role": "user", "content": message})

    # MoI reads beta from the environment at request time.
    os.environ["MIXINPUTS_BETA"] = str(beta)

    # BUG FIX: the original omitted ``stream=True`` and yielded exactly once
    # with the full response, contradicting the streaming contract described
    # in the docstring.  We now stream and yield per delta.
    response = client.chat.completions.create(
        model="Qwen/Qwen3-4B",
        messages=messages,
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        stream=True,
    )

    assistant = ""
    for chunk in response:
        # Some keep-alive chunks may carry no choices; skip them defensively.
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta.content or ""
        assistant += delta
        yield history + [(message, assistant)]  # live update


# ──────────────────────────────────────────────────────────────────────────────
# Gradio UI
# ──────────────────────────────────────────────────────────────────────────────

# Top-level UI definition: sliders for sampling/MoI controls, a Chatbot
# display, a textbox wired to ``stream_completion``, and a clear button.
with gr.Blocks(title="🎨 Mixture of Inputs (MoI) Demo") as demo:
    gr.Markdown(
        "## 🎨 Mixture of Inputs (MoI) Demo with Qwen3-4B\n"
        "Streaming vLLM demo with dynamic **beta** adjustment in MoI, feel how it affects the model!\n"
        "(higher beta β†’ less blending).\n"
        "πŸ“•Paper: https://arxiv.org/abs/2505.14827 \n"
        "πŸ’»Code: https://github.com/EvanZhuang/mixinputs \n"
    )

    with gr.Row():  # sliders first
        # Slider defaults: β=1.0, temperature=0.6, top-p=0.80, 2048 new tokens.
        beta        = gr.Slider(0.0, 10.0, value=1.0,  step=0.1,  label="MoI Ξ²")
        temperature = gr.Slider(0.1, 1.0,  value=0.6,  step=0.1,  label="Temperature")
        top_p       = gr.Slider(0.1, 1.0,  value=0.80, step=0.05, label="Top‑p")
        max_tokens  = gr.Slider(1,   3072, value=2048,  step=1,    label="Max new tokens")

    chatbot   = gr.Chatbot(height=450)
    user_box  = gr.Textbox(placeholder="Type a message and press Enter…", show_label=False)
    clear_btn = gr.Button("Clear chat")

    # Submitting the textbox streams the reply into the Chatbot; the slider
    # values are passed positionally after the message and history.
    user_box.submit(
        fn=stream_completion,
        inputs=[user_box, chatbot, max_tokens, temperature, top_p, beta],
        outputs=chatbot,
    )

    # Clearing simply resets the Chatbot value to None (empty history).
    clear_btn.click(lambda: None, None, chatbot, queue=False)

# ──────────────────────────────────────────────────────────────────────────────
# Launch the Gradio app only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()