from functools import lru_cache

import gradio as gr
from transformers import pipeline, set_seed


# === 1. Cache the model loader ===
@lru_cache(maxsize=1)
def get_generator(model_name: str):
    """Return a text-generation pipeline for ``model_name``, loading it on first use.

    ``lru_cache(maxsize=1)`` keeps only the most recently requested pipeline.
    The cache is module-level, so it is shared by every caller in this Python
    process (not one copy per browser session). Requesting a different model
    name builds a new pipeline and evicts the previous cache entry.

    Args:
        model_name: Model identifier or path forwarded to ``pipeline(model=...)``.

    Returns:
        The object returned by ``transformers.pipeline("text-generation", ...)``.
    """
    # NOTE(review): trust_remote_code=True lets the model repository run its own
    # Python code at load time, and model_name is supplied from an editable UI
    # textbox. Confirm this is intended before exposing the app to untrusted users.
    return pipeline(
        "text-generation",
        model=model_name,
        trust_remote_code=True,
        device_map="auto",
    )


# === 2. Chat callback ===
def chat(user_input, history, model_name, max_length, temperature, seed):
    """Generate one assistant reply and return the updated conversation.

    Args:
        user_input: Text typed by the user. Blank input is ignored.
        history: Conversation so far as a list of ``{"role", "content"}`` dicts
            ("messages" format); ``None`` is treated as an empty conversation.
            The list passed in is never mutated.
        model_name: Model identifier handed to ``get_generator``.
        max_length: Maximum number of tokens to generate for the reply.
        temperature: Sampling temperature. Values <= 0 switch to greedy decoding.
        seed: Random seed. ``None``, 0 or a negative value leaves the RNGs alone.

    Returns:
        A ``(messages, messages)`` tuple holding the same updated list twice,
        one entry for each output component the callback is wired to.
    """
    # Work on a copy so the caller's list (the session state) is not modified
    # behind its back, and so a missing history does not crash on append.
    messages = list(history or [])

    # Nothing to answer: skip the (expensive) model call entirely.
    if not user_input or not user_input.strip():
        return messages, messages

    # Numeric UI widgets may deliver floats (e.g. 42.0); seeding needs an int.
    if seed and seed > 0:
        set_seed(int(seed))

    # Lazy-load the model
    generator = get_generator(model_name)

    # Build prompt in Mistral's instruction format.
    # NOTE(review): the bare "<>" markers look like system-prompt tags whose
    # names were stripped somewhere upstream -- confirm the intended template.
    prompt = (
        "[INST] <>\nYou are a helpful assistant.\n<>\n\n"
        f"{user_input}\n[/INST]"
    )

    # max_new_tokens bounds only the reply; max_length would also count the
    # prompt tokens, so a long prompt could leave no room for an answer.
    generation_kwargs = {
        "max_new_tokens": int(max_length),
        "num_return_sequences": 1,
    }
    if temperature and temperature > 0:
        generation_kwargs["do_sample"] = True
        generation_kwargs["temperature"] = float(temperature)
    else:
        # Sampling with a non-positive temperature is rejected by transformers;
        # treat it as a request for deterministic (greedy) decoding instead.
        generation_kwargs["do_sample"] = False

    outputs = generator(prompt, **generation_kwargs)

    # generated_text echoes the prompt; keep only what follows the last [/INST].
    response = outputs[0]["generated_text"].split("[/INST]")[-1].strip()

    # Append to history as dicts for "messages" format
    messages.append({"role": "user", "content": user_input})
    messages.append({"role": "assistant", "content": response})
    return messages, messages
# === 3. Build Gradio UI ===
with gr.Blocks() as demo:
    gr.Markdown("## 🤖 Mistral-7B-Instruct Chatbot (Gradio)")

    # Chat display ("messages" format: a list of role/content dicts) and the
    # gr.State that carries the conversation between callback invocations.
    chatbot = gr.Chatbot(type="messages")
    state = gr.State([])

    with gr.Row():
        # Left column: message entry and send button.
        with gr.Column(scale=3):
            inp = gr.Textbox(placeholder="Type your message...", lines=2, show_label=False)
            submit = gr.Button("Send")
        # Right column: generation settings forwarded to the chat callback.
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            model_name = gr.Textbox(
                value="mistralai/Mistral-7B-Instruct-v0.3",
                label="Model name",
            )
            max_length = gr.Slider(50, 1024, 256, step=50, label="Max tokens")
            temperature = gr.Slider(0.0, 1.0, 0.7, step=0.05, label="Temperature")
            # precision=0 makes the component hand the callback an int rather
            # than a float, which is what seeding functions expect.
            seed = gr.Number(42, label="Random seed (0 disables)", precision=0)

    # Wire the button: inputs include the gr.State; outputs update both the
    # Chatbot display and the stored state.
    submit.click(
        fn=chat,
        inputs=[inp, state, model_name, max_length, temperature, seed],
        outputs=[chatbot, state],
    )

if __name__ == "__main__":
    demo.launch()