# shedacoding / app.py
# Last commit by shaheerawan3: "Update app.py" (072a834, verified)
import gradio as gr
from transformers import pipeline, set_seed
from functools import lru_cache
# === 1. Cache the loaded pipeline (process-wide; only the most recent model is kept) ===
@lru_cache(maxsize=1)
def get_generator(model_name: str):
    """Build, and memoize, a text-generation pipeline for ``model_name``.

    The ``lru_cache(maxsize=1)`` decorator keeps a single entry, so calling
    again with the same name reuses the pipeline, while asking for a
    different name evicts the previous entry and loads a new one.

    Args:
        model_name: Model identifier passed to ``pipeline`` as ``model``.

    Returns:
        The object returned by ``transformers.pipeline``.
    """
    # NOTE(review): trust_remote_code=True lets the model repository run its
    # own Python code at load time. Confirm that every caller passes a
    # trusted model name before exposing this publicly.
    loader_options = {
        "model": model_name,
        "trust_remote_code": True,
        "device_map": "auto",
    }
    text_generator = pipeline("text-generation", **loader_options)
    return text_generator
# === 2. Chat callback ===
def chat(user_input, history, model_name, max_length, temperature, seed):
    """Generate one assistant reply and append the exchange to the history.

    Args:
        user_input: Text typed by the user. Blank or whitespace-only input
            is ignored and the history is returned unchanged.
        history: Previous turns as a list of ``{"role": ..., "content": ...}``
            dicts. ``None`` is treated as an empty conversation.
        model_name: Model identifier forwarded to ``get_generator``.
        max_length: Upper bound on the number of *newly generated* tokens
            (the prompt does not count against it).
        temperature: Sampling temperature. Values ``<= 0`` (or ``None``)
            switch to greedy decoding instead of sampling.
        seed: Random seed. ``None``, ``0`` or a negative value leaves the
            random state untouched.

    Returns:
        A ``(history, history)`` tuple: the same updated list twice, once
        for the chat display and once for the stored state.
    """
    # Work on a copy so the caller's list is never mutated in place.
    history = list(history or [])

    # Nothing to answer: skip the expensive model load / generation.
    if not user_input or not user_input.strip():
        return history, history

    # Seeding helpers expect an integer; UI number fields may deliver floats.
    if seed and seed > 0:
        set_seed(int(seed))

    # Lazy-load the model
    generator = get_generator(model_name)

    # Single-turn prompt: a system section in <<SYS>> tags followed by the
    # user message, wrapped in [INST] ... [/INST]. Earlier turns stored in
    # `history` are not sent to the model.
    prompt = (
        "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\n\n"
        f"{user_input}\n[/INST]"
    )

    generation_kwargs = {
        # Bound only the completion, so a long prompt cannot use up the
        # whole token budget before generation starts.
        "max_new_tokens": max(1, int(max_length)),
        "num_return_sequences": 1,
        # Ask for the completion only instead of splitting the echoed
        # prompt on "[/INST]", which breaks if the reply contains that tag.
        "return_full_text": False,
    }
    if temperature and temperature > 0:
        generation_kwargs["do_sample"] = True
        generation_kwargs["temperature"] = float(temperature)
    else:
        # Sampling requires a strictly positive temperature; fall back to
        # deterministic greedy decoding.
        generation_kwargs["do_sample"] = False

    outputs = generator(prompt, **generation_kwargs)
    response = outputs[0]["generated_text"].strip()

    # Append to history as dicts for "messages" format
    history.append({"role": "user", "content": user_input})
    history.append({"role": "assistant", "content": response})
    return history, history
# === 3. Build Gradio UI ===
with gr.Blocks() as demo:
    gr.Markdown("## 🤖 Mistral-7B-Instruct Chatbot (Gradio)")

    # Chat display (rendered from role/content dicts) and the state object
    # that carries the conversation history between calls.
    chatbot = gr.Chatbot(type="messages")
    state = gr.State([])

    with gr.Row():
        # Left column: message entry.
        with gr.Column(scale=3):
            inp = gr.Textbox(placeholder="Type your message...", lines=2, show_label=False)
            submit = gr.Button("Send")
        # Right column: generation settings passed to `chat` on every send.
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            model_name = gr.Textbox(value="mistralai/Mistral-7B-Instruct-v0.3", label="Model name")
            max_length = gr.Slider(50, 1024, 256, step=50, label="Max tokens")
            temperature = gr.Slider(0.0, 1.0, 0.7, step=0.05, label="Temperature")
            # precision=0 rounds the value to an integer, which is what a
            # random seed has to be.
            seed = gr.Number(42, label="Random seed (0 disables)", precision=0)

    # Wire the button: inputs include the gr.State; outputs update both
    # the Chatbot and the state.
    submit.click(
        fn=chat,
        inputs=[inp, state, model_name, max_length, temperature, seed],
        outputs=[chatbot, state],
    )

if __name__ == "__main__":
    demo.launch()