shaheerawan3 committed
Commit 072a834 · verified · 1 parent: 308adb1

Update app.py

Files changed (1): app.py (+24 -11)
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
 from transformers import pipeline, set_seed
 from functools import lru_cache
 
-# === 1. Optional: Cache the pipeline loader to avoid reloading ===
+# === 1. Cache the model loader once per session ===
 @lru_cache(maxsize=1)
 def get_generator(model_name: str):
     return pipeline(
@@ -12,15 +12,20 @@ def get_generator(model_name: str):
         device_map="auto"
     )
 
-# === 2. Chat function ===
+# === 2. Chat callback ===
 def chat(user_input, history, model_name, max_length, temperature, seed):
+    # Set seed if provided
     if seed and seed > 0:
         set_seed(seed)
+    # Lazy-load the model
     generator = get_generator(model_name)
+
+    # Build prompt in Mistral's instruction format
    prompt = (
         "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\n\n"
         f"{user_input}\n[/INST]"
     )
+    # Generate response
     outputs = generator(
         prompt,
         max_length=max_length,
@@ -29,28 +34,36 @@ def chat(user_input, history, model_name, max_length, temperature, seed):
         num_return_sequences=1
     )
     response = outputs[0]["generated_text"].split("[/INST]")[-1].strip()
-    history.append((user_input, response))
+
+    # Append to history as dicts for "messages" format
+    history.append({"role": "user", "content": user_input})
+    history.append({"role": "assistant", "content": response})
     return history, history
 
 # === 3. Build Gradio UI ===
 with gr.Blocks() as demo:
     gr.Markdown("## 🤖 Mistral-7B-Instruct Chatbot (Gradio)")
+
+    # Chatbot and session state
+    chatbot = gr.Chatbot(type="messages")
+    state = gr.State([])
+
     with gr.Row():
         with gr.Column(scale=3):
-            chatbox = gr.Chatbot()
-            inp = gr.Textbox(show_label=False, placeholder="Type your message here...", lines=2)
-            submit = gr.Button("Send")
+            inp = gr.Textbox(placeholder="Type your message...", lines=2, show_label=False)
+            submit = gr.Button("Send")
         with gr.Column(scale=1):
             gr.Markdown("### Settings")
-            model_name = gr.Textbox(value="mistralai/Mistral-7B-Instruct-v0.3", label="Model name")
-            max_length = gr.Slider(50, 1024, 256, step=50, label="Max tokens")
+            model_name = gr.Textbox(value="mistralai/Mistral-7B-Instruct-v0.3", label="Model name")
+            max_length = gr.Slider(50, 1024, 256, step=50, label="Max tokens")
             temperature = gr.Slider(0.0, 1.0, 0.7, step=0.05, label="Temperature")
-            seed = gr.Number(42, label="Random seed (0 disables)")
+            seed = gr.Number(42, label="Random seed (0 disables)")
 
+    # Wire the button: inputs include the gr.State; outputs update both Chatbot and state
     submit.click(
         fn=chat,
-        inputs=[inp, chatbox, model_name, max_length, temperature, seed],
-        outputs=[chatbox, chatbox]
+        inputs=[inp, state, model_name, max_length, temperature, seed],
+        outputs=[chatbot, state]
     )
 
 if __name__ == "__main__":
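The main behavioral change in this commit is the history format: chat() now returns a list of role/content dicts, which is what gr.Chatbot(type="messages") renders, instead of the old list of (user, assistant) tuples. A minimal sketch of the data shape one turn produces (illustration only, with the Mistral pipeline call replaced by a hard-coded reply):

# Hypothetical stand-in for one round trip through the new chat() callback;
# the real code generates `response` with the Mistral-7B-Instruct pipeline.
history = []
user_input = "Hello!"
response = "Hi, how can I help?"
history.append({"role": "user", "content": user_input})
history.append({"role": "assistant", "content": response})
# chat() returns (history, history): one copy refreshes the Chatbot display,
# the other is written back into gr.State and passed in again on the next turn.
print(history)
# [{'role': 'user', 'content': 'Hello!'}, {'role': 'assistant', 'content': 'Hi, how can I help?'}]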