# Hugging Face Spaces app — status: Sleeping
import gradio as gr | |
from transformers import pipeline, set_seed | |
from functools import lru_cache | |
# === 1. Cache the model loader once per process ===
@lru_cache(maxsize=1)
def get_generator(model_name: str):
    """Return a text-generation pipeline for *model_name*, building it at most once.

    The result is memoized with ``lru_cache`` so repeated chat turns reuse the
    already-loaded pipeline instead of constructing a new one per message.
    ``maxsize=1`` keeps only the most recently requested model in the cache:
    the name comes from a free-text UI field, so an unbounded cache could keep
    arbitrarily many models referenced.

    Args:
        model_name: Model id or path forwarded to ``transformers.pipeline``.

    Returns:
        The ``"text-generation"`` pipeline for that model.
    """
    return pipeline(
        "text-generation",
        model=model_name,
        # NOTE(review): trust_remote_code=True executes code shipped in the model
        # repo, and model_name is user-editable in the UI. Consider disabling it
        # or restricting model_name to an allow-list.
        trust_remote_code=True,
        device_map="auto",
    )
# === 2. Chat callback ===
_SYSTEM_PROMPT = "You are a helpful assistant."


def _build_messages(history, user_input):
    """Return the chat-format message list for this turn (prior turns + new input).

    Copies the role/content of each prior turn so *history* is never mutated.
    The system instruction is folded into the first user turn instead of being
    sent as a separate "system" message, because the model name is user-editable
    and some models' chat templates may reject a system role.
    """
    turns = [{"role": m["role"], "content": m["content"]} for m in history]
    turns.append({"role": "user", "content": user_input})
    for turn in turns:
        if turn["role"] == "user":
            turn["content"] = f"{_SYSTEM_PROMPT}\n\n{turn['content']}"
            break
    return turns


def chat(user_input, history, model_name, max_length, temperature, seed):
    """Generate one assistant reply and append the exchange to the conversation.

    Args:
        user_input: Text typed by the user.
        history: Conversation so far as a list of {"role", "content"} dicts
            (Gradio "messages" format). Not mutated; a new list is returned.
        model_name: Model id forwarded to ``get_generator``.
        max_length: Maximum number of *new* tokens to generate.
        temperature: Sampling temperature; values <= 0 select greedy decoding.
        seed: RNG seed; None, 0 or negative values leave the RNG unseeded.

    Returns:
        ``(history, history)``: the updated list, once for the Chatbot component
        and once for the session ``gr.State``.
    """
    history = list(history or [])

    # Nothing to send: keep the conversation unchanged rather than generating
    # a reply to an empty prompt.
    if not user_input or not user_input.strip():
        return history, history

    # gr.Number returns a float by default; numpy's seeding (used by set_seed)
    # rejects floats, so coerce to int.
    if seed and seed > 0:
        set_seed(int(seed))

    generator = get_generator(model_name)

    # max_new_tokens (not max_length) so the prompt does not eat into the
    # "Max tokens" budget shown in the UI; Gradio sliders may yield floats.
    gen_kwargs = {"max_new_tokens": int(max_length)}
    if temperature and temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=float(temperature))
    else:
        # Sampling requires a strictly positive temperature; treat 0 as greedy.
        gen_kwargs["do_sample"] = False

    # Passing chat messages lets the pipeline apply the model's own chat
    # template; return_full_text=False yields only the newly generated reply,
    # so no fragile splitting on "[/INST]" is needed.
    outputs = generator(
        _build_messages(history, user_input),
        return_full_text=False,
        num_return_sequences=1,
        **gen_kwargs,
    )
    response = outputs[0]["generated_text"].strip()

    # Append to history as dicts for "messages" format.
    history.append({"role": "user", "content": user_input})
    history.append({"role": "assistant", "content": response})
    return history, history
# === 3. Build Gradio UI ===
with gr.Blocks() as demo:
    gr.Markdown("## 🤖 Mistral-7B-Instruct Chatbot (Gradio)")

    # Chatbot display plus the per-session conversation list fed back to chat().
    chatbot = gr.Chatbot(type="messages")
    state = gr.State([])

    with gr.Row():
        with gr.Column(scale=3):
            inp = gr.Textbox(placeholder="Type your message...", lines=2, show_label=False)
            submit = gr.Button("Send")
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            model_name = gr.Textbox(value="mistralai/Mistral-7B-Instruct-v0.3", label="Model name")
            # Step divides the minimum, maximum and default, so every one of
            # them is a reachable slider position.
            max_length = gr.Slider(minimum=64, maximum=1024, value=256, step=64, label="Max tokens")
            temperature = gr.Slider(0.0, 1.0, 0.7, step=0.05, label="Temperature")
            # precision=0 makes the component deliver an int seed.
            seed = gr.Number(42, precision=0, label="Random seed (0 disables)")

    # Send on button click or Enter in the textbox. Inputs include the gr.State;
    # outputs update both the Chatbot and the state. Afterwards the textbox is
    # cleared so the next message starts empty.
    gr.on(
        triggers=[submit.click, inp.submit],
        fn=chat,
        inputs=[inp, state, model_name, max_length, temperature, seed],
        outputs=[chatbot, state],
    ).then(fn=lambda: "", inputs=None, outputs=inp)

if __name__ == "__main__":
    demo.launch()