shaheerawan3 committed on
Commit
308adb1
·
verified ·
1 Parent(s): 8bbc8a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -56
app.py CHANGED
@@ -1,38 +1,26 @@
1
- import gradio as gr # Gradio for UI :contentReference[oaicite:3]{index=3}
2
- from transformers import pipeline, set_seed # Transformers pipeline :contentReference[oaicite:4]{index=4}
 
3
 
4
- # Set random seed reproducibility
5
- def maybe_set_seed(seed: int):
6
- if seed and seed > 0:
7
- set_seed(seed)
8
-
9
- # Load the model once per Space session
10
- @gr.cache_resource # caches across sessions :contentReference[oaicite:5]{index=5}
11
- def load_generator(model_name: str):
12
  return pipeline(
13
  "text-generation",
14
  model=model_name,
15
  trust_remote_code=True,
16
- device_map="auto", # auto GPU/CPU placement :contentReference[oaicite:6]{index=6}
17
  )
18
 
19
- # The chat function called on each submit
20
- def chat(
21
- user_input: str,
22
- history: list[tuple[str, str]],
23
- model_name: str,
24
- max_length: int,
25
- temperature: float,
26
- seed: int
27
- ):
28
- maybe_set_seed(seed)
29
- generator = load_generator(model_name)
30
- # Prepare instruction-wrapped prompt
31
  prompt = (
32
  "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\n\n"
33
  f"{user_input}\n[/INST]"
34
  )
35
- # Generate response
36
  outputs = generator(
37
  prompt,
38
  max_length=max_length,
@@ -40,51 +28,30 @@ def chat(
40
  do_sample=True,
41
  num_return_sequences=1
42
  )
43
- raw = outputs[0]["generated_text"]
44
- # Extract only assistant’s reply
45
- response = raw.split("[/INST]")[-1].strip()
46
- # Append to chat history
47
  history.append((user_input, response))
48
  return history, history
49
 
50
- # Build the Gradio Blocks interface
51
  with gr.Blocks() as demo:
52
- gr.Markdown("## 🤖 Mistral-7B-Instruct Chatbot (Gradio)") # title :contentReference[oaicite:7]{index=7}
53
  with gr.Row():
54
  with gr.Column(scale=3):
55
- chatbox = gr.Chatbot(label="Chat History")
56
- with gr.Row():
57
- inp = gr.Textbox(
58
- placeholder="Type your message here...",
59
- show_label=False,
60
- lines=2
61
- )
62
- submit = gr.Button("Send")
63
  with gr.Column(scale=1):
64
- gr.Markdown("### Settings 🚀")
65
- model_name = gr.Textbox(
66
- label="Model name",
67
- value="mistralai/Mistral-7B-Instruct-v0.3"
68
- )
69
- max_length = gr.Slider(
70
- minimum=50, maximum=1024, step=50,
71
- value=256, label="Max tokens"
72
- )
73
- temperature = gr.Slider(
74
- minimum=0.0, maximum=1.0, step=0.05,
75
- value=0.7, label="Temperature"
76
- )
77
- seed = gr.Number(
78
- value=42, label="Random seed (0 disables)"
79
- )
80
 
81
- # Wire up the event
82
  submit.click(
83
  fn=chat,
84
  inputs=[inp, chatbox, model_name, max_length, temperature, seed],
85
  outputs=[chatbox, chatbox]
86
  )
87
 
88
- # Launch the app; in Spaces, no need to set share=True :contentReference[oaicite:8]{index=8}
89
  if __name__ == "__main__":
90
  demo.launch()
 
1
+ import gradio as gr
2
+ from transformers import pipeline, set_seed
3
+ from functools import lru_cache
4
 
5
+ # === 1. Optional: Cache the pipeline loader to avoid reloading ===
6
+ @lru_cache(maxsize=1)
7
+ def get_generator(model_name: str):
 
 
 
 
 
8
  return pipeline(
9
  "text-generation",
10
  model=model_name,
11
  trust_remote_code=True,
12
+ device_map="auto"
13
  )
14
 
15
+ # === 2. Chat function ===
16
+ def chat(user_input, history, model_name, max_length, temperature, seed):
17
+ if seed and seed > 0:
18
+ set_seed(seed)
19
+ generator = get_generator(model_name)
 
 
 
 
 
 
 
20
  prompt = (
21
  "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\n\n"
22
  f"{user_input}\n[/INST]"
23
  )
 
24
  outputs = generator(
25
  prompt,
26
  max_length=max_length,
 
28
  do_sample=True,
29
  num_return_sequences=1
30
  )
31
+ response = outputs[0]["generated_text"].split("[/INST]")[-1].strip()
 
 
 
32
  history.append((user_input, response))
33
  return history, history
34
 
35
+ # === 3. Build Gradio UI ===
36
  with gr.Blocks() as demo:
37
+ gr.Markdown("## 🤖 Mistral-7B-Instruct Chatbot (Gradio)")
38
  with gr.Row():
39
  with gr.Column(scale=3):
40
+ chatbox = gr.Chatbot()
41
+ inp = gr.Textbox(show_label=False, placeholder="Type your message here...", lines=2)
42
+ submit = gr.Button("Send")
 
 
 
 
 
43
  with gr.Column(scale=1):
44
+ gr.Markdown("### Settings")
45
+ model_name = gr.Textbox(value="mistralai/Mistral-7B-Instruct-v0.3", label="Model name")
46
+ max_length = gr.Slider(50, 1024, 256, step=50, label="Max tokens")
47
+ temperature = gr.Slider(0.0, 1.0, 0.7, step=0.05, label="Temperature")
48
+ seed = gr.Number(42, label="Random seed (0 disables)")
 
 
 
 
 
 
 
 
 
 
 
49
 
 
50
  submit.click(
51
  fn=chat,
52
  inputs=[inp, chatbox, model_name, max_length, temperature, seed],
53
  outputs=[chatbox, chatbox]
54
  )
55
 
 
56
  if __name__ == "__main__":
57
  demo.launch()