File size: 2,491 Bytes
308adb1
 
 
24f0c3b
072a834
308adb1
 
24f0c3b
 
 
 
308adb1
24f0c3b
 
072a834
308adb1
072a834
308adb1
 
072a834
308adb1
072a834
 
24f0c3b
 
 
 
072a834
24f0c3b
 
 
 
 
 
 
308adb1
072a834
 
 
 
24f0c3b
 
308adb1
24f0c3b
308adb1
072a834
 
 
 
 
24f0c3b
 
072a834
 
24f0c3b
308adb1
072a834
 
308adb1
072a834
24f0c3b
072a834
24f0c3b
 
072a834
 
24f0c3b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import gradio as gr
from transformers import pipeline, set_seed
from functools import lru_cache

# === 1. Cache the model loader once per session ===
@lru_cache(maxsize=1)
def get_generator(model_name: str):
    """Build — at most once per model name — a text-generation pipeline.

    The lru_cache(maxsize=1) keeps the last-used model resident so repeated
    chat turns don't reload multi-GB weights.
    """
    options = {
        "model": model_name,
        "trust_remote_code": True,   # required by some community model repos
        "device_map": "auto",        # let accelerate place weights on GPU/CPU
    }
    return pipeline("text-generation", **options)

# === 2. Chat callback ===
def chat(user_input, history, model_name, max_length, temperature, seed):
    """Generate one assistant reply and append the exchange to *history*.

    Parameters
    ----------
    user_input : str
        Message typed by the user.
    history : list[dict]
        Conversation so far, in Chatbot "messages" format
        ({"role": ..., "content": ...} dicts). Mutated in place.
    model_name : str
        Hugging Face model id handed to the cached pipeline loader.
    max_length : int
        Total token budget for prompt + reply (pipeline ``max_length``).
    temperature : float
        Sampling temperature; 0 falls back to greedy decoding.
    seed : int | float
        RNG seed; 0 or empty disables seeding.

    Returns
    -------
    tuple[list, list]
        The same history list twice — once for the Chatbot widget,
        once for the gr.State.
    """
    # gr.Number delivers a float; transformers' set_seed forwards it to
    # numpy.random.seed, which rejects floats — cast to int before seeding.
    if seed and seed > 0:
        set_seed(int(seed))
    # Lazy-load the model (memoized by get_generator's lru_cache)
    generator = get_generator(model_name)

    # Build prompt in Mistral's instruction format
    prompt = (
        "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\n\n"
        f"{user_input}\n[/INST]"
    )
    # temperature == 0.0 is reachable from the UI slider, and transformers
    # raises ValueError for do_sample=True with temperature 0 — use greedy
    # decoding in that case instead of crashing.
    sample = temperature > 0
    outputs = generator(
        prompt,
        max_length=max_length,
        temperature=temperature if sample else None,
        do_sample=sample,
        num_return_sequences=1
    )
    # The pipeline echoes the prompt; keep only the text after "[/INST]".
    response = outputs[0]["generated_text"].split("[/INST]")[-1].strip()

    # Append to history as dicts for "messages" format
    history.append({"role": "user",      "content": user_input})
    history.append({"role": "assistant", "content": response})
    return history, history

# === 3. Build Gradio UI ===
with gr.Blocks() as demo:
    gr.Markdown("## 🤖 Mistral-7B-Instruct Chatbot (Gradio)")

    # Conversation display plus a parallel per-session message list.
    chatbot = gr.Chatbot(type="messages")
    state = gr.State([])

    with gr.Row():
        # Left column: message entry.
        with gr.Column(scale=3):
            inp = gr.Textbox(placeholder="Type your message...", lines=2, show_label=False)
            submit = gr.Button("Send")
        # Right column: generation settings.
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            model_name = gr.Textbox(value="mistralai/Mistral-7B-Instruct-v0.3", label="Model name")
            max_length = gr.Slider(50, 1024, 256, step=50, label="Max tokens")
            temperature = gr.Slider(0.0, 1.0, 0.7, step=0.05, label="Temperature")
            seed = gr.Number(42, label="Random seed (0 disables)")

    # Button wiring: the gr.State rides along with the inputs, and chat()
    # returns the updated history for both the Chatbot widget and the state.
    submit.click(
        fn=chat,
        inputs=[inp, state, model_name, max_length, temperature, seed],
        outputs=[chatbot, state],
    )

if __name__ == "__main__":
    demo.launch()