phanerozoic commited on
Commit
885a86a
Β·
verified Β·
1 Parent(s): 6f67928

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -74
app.py CHANGED
@@ -1,10 +1,5 @@
1
  """
2
- SchoolSpiritΒ AI – Granite‑3.3‑2B chatbot Space
3
- ----------------------------------------------
4
- β€’ IBM Granite‑3.3‑2B‑Instruct (Apache‑2), runs in HF CPU Space.
5
- β€’ Keeps last MAX_TURNS exchanges to fit context.
6
- β€’ β€œClearΒ Chat” button resets conversation.
7
- β€’ Extensive error‑handling: model‑load, inference, bad input.
8
  """
9
 
10
  import re
@@ -16,14 +11,14 @@ from transformers import (
16
  )
17
  from transformers.utils import logging as hf_logging
18
 
19
- # ────────── Configuration ───────────────────────────────────────────────────
20
  hf_logging.set_verbosity_error()
21
  LOG = hf_logging.get_logger("SchoolSpirit")
22
 
23
  MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
24
- MAX_TURNS = 6 # history turns to keep
25
- MAX_TOKENS = 200 # response length
26
- MAX_INPUT_CH = 400 # user message length guard
27
 
28
  SYSTEM_MSG = (
29
  "You are SchoolSpiritΒ AI, the upbeat digital mascot for a company that "
@@ -32,13 +27,11 @@ SYSTEM_MSG = (
32
  "say so and suggest contacting a human. Never request personal data."
33
  )
34
 
35
- # ────────── Model loading with fail‑safe ────────────────────────────────────
36
  try:
37
  tok = AutoTokenizer.from_pretrained(MODEL_ID)
38
  model = AutoModelForCausalLM.from_pretrained(
39
- MODEL_ID,
40
- device_map="auto",
41
- torch_dtype="auto",
42
  )
43
  generator = pipeline(
44
  "text-generation",
@@ -54,77 +47,77 @@ except Exception as exc: # noqa: BLE001
54
  generator = None
55
  LOG.error(MODEL_ERR)
56
 
57
- # ────────── Helper utilities ────────────────────────────────────────────────
58
- def truncate(hist):
59
- """Return last MAX_TURNS (user,bot) tuples."""
60
- return hist[-MAX_TURNS:] if len(hist) > MAX_TURNS else hist
61
-
62
-
63
  def clean(text: str) -> str:
64
- """Normalize whitespace and guarantee non‑empty."""
65
  return re.sub(r"\s+", " ", text.strip()) or "…"
66
 
67
 
68
- def safe_generate(prompt: str) -> str:
69
- """Call model.generate, catch & log any error, always return a string."""
70
- try:
71
- completion = generator(prompt)[0]["generated_text"]
72
- reply = clean(completion.split("AI:", 1)[-1])
73
- except Exception as err: # noqa: BLE001
74
- LOG.error(f"Inference error: {err}")
75
- reply = (
76
- "Sorryβ€”I'm having trouble right now. "
77
- "Please try again in a moment."
78
- )
79
- return reply
80
 
81
- # ────────── Chat callback ───────────────────────────────────────────────────
82
- def chat(history, user_msg):
83
- history = list(history) # guaranteed list of tuples
 
 
 
84
 
85
- # Fatal start‑up failure
86
  if MODEL_ERR:
87
- history.append((user_msg, MODEL_ERR))
88
- return history, ""
 
 
89
 
90
- user_msg = clean(user_msg or "")
 
 
91
  if not user_msg:
92
- history.append(("", "Please enter a message."))
93
- return history, ""
94
  if len(user_msg) > MAX_INPUT_CH:
95
- history.append(
96
- (user_msg, f"Message too long (>{MAX_INPUT_CH} chars).")
97
  )
98
- return history, ""
99
 
100
- history = truncate(history)
101
 
102
  # Build prompt
103
- prompt_lines = [SYSTEM_MSG]
104
- for u, a in history:
105
- prompt_lines += [f"User: {u}", f"AI: {a}"]
106
- prompt_lines += [f"User: {user_msg}", "AI:"]
107
- prompt = "\n".join(prompt_lines)
108
-
109
- reply = safe_generate(prompt)
110
-
111
- history.append((user_msg, reply))
112
- return history, ""
113
-
114
- # ────────── Clear chat callback ─────────────────────────────────────────────
115
- def clear_chat():
116
- return [], ""
117
-
118
- # ────────── UI definition ───────────────────────────────────────────────────
119
- with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
120
- gr.Markdown("# SchoolSpiritΒ AI Chat")
121
- chatbot = gr.Chatbot(type="tuples")
122
- msg_box = gr.Textbox(placeholder="Ask me anything about SchoolSpiritΒ AI…")
123
- send_btn = gr.Button("Send")
124
- clear_btn = gr.Button("Clear Chat", variant="secondary")
125
-
126
- send_btn.click(chat, [chatbot, msg_box], [chatbot, msg_box])
127
- msg_box.submit(chat, [chatbot, msg_box], [chatbot, msg_box])
128
- clear_btn.click(clear_chat, outputs=[chatbot, msg_box])
129
-
130
- demo.launch()
 
 
 
 
 
 
1
  """
2
+ SchoolSpiritΒ AI – Granite‑3.3‑2B chatbot Space (messages API)
 
 
 
 
 
3
  """
4
 
5
  import re
 
11
  )
12
  from transformers.utils import logging as hf_logging
13
 
14
+ # ─── Config ────────────────────────────────────────────────────────────────
15
  hf_logging.set_verbosity_error()
16
  LOG = hf_logging.get_logger("SchoolSpirit")
17
 
18
  MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
19
+ MAX_TURNS = 6 # keep last N user/assistant pairs
20
+ MAX_TOKENS = 200
21
+ MAX_INPUT_CH = 400
22
 
23
  SYSTEM_MSG = (
24
  "You are SchoolSpiritΒ AI, the upbeat digital mascot for a company that "
 
27
  "say so and suggest contacting a human. Never request personal data."
28
  )
29
 
30
+ # ─── Model load with fail‑safe ─────────────────────────────────────────────
31
  try:
32
  tok = AutoTokenizer.from_pretrained(MODEL_ID)
33
  model = AutoModelForCausalLM.from_pretrained(
34
+ MODEL_ID, device_map="auto", torch_dtype="auto"
 
 
35
  )
36
  generator = pipeline(
37
  "text-generation",
 
47
  generator = None
48
  LOG.error(MODEL_ERR)
49
 
50
+ # ─── Helpers ───────────────────────────────────────────────────────────────
 
 
 
 
 
51
def clean(text: str) -> str:
    """Collapse all runs of whitespace to single spaces; never return empty.

    A blank/whitespace-only input yields the placeholder ellipsis so callers
    can rely on a non-empty string.
    """
    collapsed = re.sub(r"\s+", " ", text.strip())
    if collapsed:
        return collapsed
    return "…"
53
 
54
 
55
def trim(messages):
    """Bound the conversation to the system prompt plus the last MAX_TURNS pairs.

    messages: list of {'role': ..., 'content': ...} dicts, oldest first.
    Returns the (possibly shortened) list; short histories are returned as-is.

    Fix: the previous version unconditionally treated messages[0] as the
    system message. When no system message is present (e.g. a plain
    user/assistant history), that pinned a stale user turn at the front.
    """
    has_system = bool(messages) and messages[0].get("role") == "system"
    head = messages[:1] if has_system else []
    if len(messages) <= len(head) + MAX_TURNS * 2:
        return messages
    # Keep the system prompt (when present) and the most recent N pairs.
    return head + messages[-MAX_TURNS * 2:]
 
 
 
 
 
 
61
 
62
# ─── Chat callback (messages API) ──────────────────────────────────────────
def chat_fn(messages):
    """Advance one chat turn over an OpenAI-style message list.

    messages: list of {'role': 'system'|'user'|'assistant', 'content': str}
              dicts, oldest first; the last entry is the pending user turn.
    Returns the updated list with an assistant reply appended.

    Fixes over the previous version:
    * The caller's list is no longer mutated in place (shallow copy).
    * An empty history no longer raises IndexError on messages[-1].
    * Validation errors are answered as assistant messages instead of
      silently overwriting the user's own message.
    * The reply is taken from the text the model generated after the prompt,
      not everything after the FIRST "AI:" (which leaked echoed turns).
    """
    messages = list(messages)  # never mutate the caller's history

    if not messages:
        return [{"role": "assistant", "content": "Please enter a message."}]

    # Fatal start-up failure: surface the load error instead of generating.
    # NOTE(review): assumes the loader leaves MODEL_ERR falsy on success —
    # confirm the try-branch assigns it.
    if MODEL_ERR:
        messages.append({"role": "assistant", "content": MODEL_ERR})
        return messages

    # Validate the pending user turn.
    user_msg = clean(messages[-1].get("content") or "")
    if not user_msg or user_msg == "…":
        messages.append(
            {"role": "assistant", "content": "Please enter a message."}
        )
        return messages
    if len(user_msg) > MAX_INPUT_CH:
        messages.append(
            {"role": "assistant",
             "content": f"Message too long (> {MAX_INPUT_CH} chars)."}
        )
        return messages
    messages[-1] = {"role": "user", "content": user_msg}

    messages = trim(messages)

    # Flatten the conversation into the plain-text format the model is
    # prompted with ("User: ... / AI: ...").
    lines = []
    for m in messages:
        role = m["role"]
        if role == "system":
            lines.append(m["content"])
        elif role == "user":
            lines.append(f"User: {m['content']}")
        else:  # assistant
            lines.append(f"AI: {m['content']}")
    prompt = "\n".join(lines) + "\nAI:"

    # Generate; any inference error becomes a polite fallback reply.
    try:
        out = generator(prompt)[0]["generated_text"]
        # The pipeline echoes the prompt; keep only the new completion.
        if out.startswith(prompt):
            completion = out[len(prompt):]
        else:
            completion = out.rsplit("AI:", 1)[-1]
        reply = clean(completion)
    except Exception as err:  # noqa: BLE001
        LOG.error(f"Inference error: {err}")
        reply = (
            "Sorryβ€”I'm having trouble right now. "
            "Please try again shortly."
        )

    messages.append({"role": "assistant", "content": reply})
    return messages
114
+
115
# ─── Gradio UI (ChatInterface handles Send & Clear) ────────────────────────
def _ui_chat(message, history):
    """Adapt gr.ChatInterface(type='messages') to chat_fn.

    ChatInterface calls fn(message, history) and expects the reply back;
    chat_fn works on a full message list. Inject SYSTEM_MSG here because
    ChatInterface has no system_prompt parameter (the previous code passed
    one and raised TypeError at startup).
    """
    convo = [{"role": "system", "content": SYSTEM_MSG}]
    # Keep only role/content; Gradio may attach extra metadata keys.
    convo += [
        {"role": m["role"], "content": m["content"]} for m in (history or [])
    ]
    convo.append({"role": "user", "content": message})
    updated = chat_fn(convo)
    return updated[-1]["content"]


gr.ChatInterface(
    fn=_ui_chat,
    title="SchoolSpiritΒ AI Chat",
    theme=gr.themes.Soft(primary_hue="blue"),
    chatbot=gr.Chatbot(height=480, type="messages"),
    type="messages",  # modern, future-proof format
).launch()