# Gradio chat demo for Mistral-7B-Instruct (Hugging Face Space).
import gradio as gr
from transformers import pipeline, set_seed
from functools import lru_cache
# === 1. Cache the model loader once per session ===
@lru_cache(maxsize=1)
def get_generator(model_name: str):
    """Build (once) and return a text-generation pipeline for *model_name*.

    With maxsize=1, entering a different model name evicts the previous
    pipeline, so at most one model stays resident at a time.
    """
    # SECURITY: trust_remote_code=True executes arbitrary Python shipped in
    # the model repository — only use with model names you trust.
    # NOTE(review): device_map="auto" requires the `accelerate` package —
    # confirm it is in the Space's requirements.
    return pipeline(
        "text-generation",
        model=model_name,
        trust_remote_code=True,
        device_map="auto"
    )
# === 2. Chat callback ===
def chat(user_input, history, model_name, max_length, temperature, seed):
    """Generate one assistant turn and append the exchange to *history*.

    Parameters
    ----------
    user_input : str          message typed by the user
    history : list[dict]      running transcript in Gradio "messages" format
    model_name : str          HF hub id forwarded to the cached loader
    max_length : int | float  generation length budget (sliders may emit floats)
    temperature : float       sampling temperature; 0 falls back to greedy
    seed : int | float        RNG seed; 0/None disables seeding

    Returns
    -------
    (history, history) — same list twice, so one callback updates both the
    Chatbot display and the gr.State copy.
    """
    # gr.Number / gr.Slider can deliver floats; transformers expects ints.
    if seed and int(seed) > 0:
        set_seed(int(seed))
    # Lazy-load (and cache) the model.
    generator = get_generator(model_name)
    # NOTE(review): "<<SYS>>" is the Llama-2 system-prompt format; Mistral
    # Instruct officially supports only plain [INST] ... [/INST] — confirm
    # the model tolerates the extra tags.
    prompt = (
        "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\n\n"
        f"{user_input}\n[/INST]"
    )
    # temperature=0.0 combined with do_sample=True raises in transformers;
    # switch to greedy decoding when the slider sits at 0.
    do_sample = temperature > 0
    outputs = generator(
        prompt,
        max_length=int(max_length),
        temperature=temperature if do_sample else None,
        do_sample=do_sample,
        num_return_sequences=1,
    )
    # generated_text echoes the prompt; keep only the text after [/INST].
    response = outputs[0]["generated_text"].split("[/INST]")[-1].strip()
    # Append both turns as dicts for the Chatbot "messages" format.
    history.append({"role": "user", "content": user_input})
    history.append({"role": "assistant", "content": response})
    return history, history
# === 3. Build Gradio UI ===
with gr.Blocks() as demo:
    gr.Markdown("## 🤖 Mistral-7B-Instruct Chatbot (Gradio)")

    # Chat transcript widget plus per-session state holding the same list.
    chatbot = gr.Chatbot(type="messages")
    state = gr.State([])

    with gr.Row():
        with gr.Column(scale=3):
            inp = gr.Textbox(placeholder="Type your message...", lines=2, show_label=False)
            submit = gr.Button("Send")
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            model_name = gr.Textbox(value="mistralai/Mistral-7B-Instruct-v0.3", label="Model name")
            max_length = gr.Slider(50, 1024, 256, step=50, label="Max tokens")
            temperature = gr.Slider(0.0, 1.0, 0.7, step=0.05, label="Temperature")
            # precision=0 makes the Number component deliver an int,
            # which is what transformers' set_seed expects.
            seed = gr.Number(42, label="Random seed (0 disables)", precision=0)

    # Wire the button: inputs include the gr.State; outputs update both
    # the Chatbot display and the stored state.
    submit.click(
        fn=chat,
        inputs=[inp, state, model_name, max_length, temperature, seed],
        outputs=[chatbot, state]
    )

if __name__ == "__main__":
    demo.launch()