phanerozoic committed
Commit 2e445c2 · verified · 1 Parent(s): e63a535

Update app.py

Files changed (1)
  1. app.py +75 -65
app.py CHANGED
@@ -1,14 +1,15 @@
-##############################################################################
-# SchoolSpirit AI Chat – full‑context, duplicate‑free implementation
-##############################################################################
-
-import os, re, time, datetime, traceback, torch
+import os
+import re
+import time
+import datetime
+import traceback
+import torch
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 from transformers.utils import logging as hf_logging
 
 # ---------------------------------------------------------------------------
-# 0. Basic logging
+# 0. Paths & basic logging helper
 # ---------------------------------------------------------------------------
 os.environ["HF_HOME"] = "/data/.huggingface"
 LOG_FILE = "/data/requests.log"
@@ -26,24 +27,24 @@ def log(msg: str):
 
 
 # ---------------------------------------------------------------------------
-# 1. Configuration
+# 1. Configuration constants
 # ---------------------------------------------------------------------------
-MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"  # 2B local model
-CONTEXT_TOKENS = 1800  # keep prompt below this many tokens
+MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"  # 2 B model fits Spaces
+CONTEXT_TOKENS = 1800  # leave head‑room for reply inside 2k window
 MAX_NEW_TOKENS = 64
 TEMPERATURE = 0.6
-MAX_INPUT_CH = 300
+MAX_INPUT_CH = 300  # UI safeguard
 
 SYSTEM_MSG = (
-    "You are **SchoolSpirit AI**, digital mascot of SchoolSpirit AI LLC, "
-    "founded by Charles Norton in 2025. The company deploys on‑prem AI chat "
-    "mascots, fine‑tunes language models, and ships turnkey GPU servers to "
-    "K‑12 schools.\n\n"
+    "You are **SchoolSpirit AI**, the official digital mascot of "
+    "SchoolSpirit AI LLC. Founded by Charles Norton in 2025, the company "
+    "deploys on‑prem AI chat mascots, fine‑tunes language models, and ships "
+    "turnkey GPU servers to K‑12 schools.\n\n"
     "RULES:\n"
-    "• Respond warmly and concisely (≤4 sentences unless asked for detail).\n"
-    "• No personaldata collection; no medical/legal/financial advice.\n"
-    "• Admit uncertainty and offer human follow‑up if unsure.\n"
-    "• Avoid profanity, politics, mature themes."
+    "• Friendly, concise (≤4 sentences unless prompted).\n"
+    "• No personal data collection; no medical/legal/financial advice.\n"
+    "• If uncertain, admit it & suggest human follow‑up.\n"
+    "• avoid profanity, politics, mature themes."
 )
 WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"
 
@@ -51,19 +52,20 @@ strip = lambda s: re.sub(r"\s+", " ", s.strip())
 
 
 # ---------------------------------------------------------------------------
-# 2. Load tokenizer + model (GPU FP‑16 → CPU fallback)
+# 2. Load tokenizer + model (GPU FP‑16 → CPU)
 # ---------------------------------------------------------------------------
 hf_logging.set_verbosity_error()
 try:
+    log("Loading tokenizer …")
     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
     if torch.cuda.is_available():
-        log("GPU foundloading model in FP‑16")
+        log("GPU detected → FP‑16")
         model = AutoModelForCausalLM.from_pretrained(
             MODEL_ID, device_map="auto", torch_dtype=torch.float16
         )
     else:
-        log("No GPU → CPU load")
+        log("CPU fallback")
         model = AutoModelForCausalLM.from_pretrained(
             MODEL_ID,
             device_map="cpu",
@@ -76,11 +78,12 @@ try:
         model=model,
         tokenizer=tokenizer,
         max_new_tokens=MAX_NEW_TOKENS,
-        temperature=TEMPERATURE,
         do_sample=True,
+        temperature=TEMPERATURE,
+        return_full_text=False,  # ← only return the newly generated text
     )
     MODEL_ERR = None
-    log("Model ready ✔")
+    log("Model loaded ✔")
 except Exception as exc:
     MODEL_ERR = f"Model load error: {exc}"
     generator = None
@@ -88,77 +91,79 @@ except Exception as exc:
 
 
 # ---------------------------------------------------------------------------
-# 3. Prompt builder that respects a token budget
+# 3. Helper: build prompt under token budget
# ---------------------------------------------------------------------------
-def build_prompt(history_raw: list[dict]) -> str:
+def build_prompt(raw_history: list[dict]) -> str:
     """
-    Accepts full message list. Drops oldest user+assistant pairs until the
-    prompt’s token length is CONTEXT_TOKENS.
+    raw_history: list [{'role':'system'|'user'|'assistant', 'content': str}, ...]
+    Keeps trimming oldest user/assistant pair until total tokens < CONTEXT_TOKENS
     """
     def render(msg):
         if msg["role"] == "system":
-            return f"System: {msg['content']}"
+            return msg["content"]
         prefix = "User:" if msg["role"] == "user" else "AI:"
         return f"{prefix} {msg['content']}"
 
-    # split out system + conversation
-    system = [m for m in history_raw if m["role"] == "system"][0]
-    convo = [m for m in history_raw if m["role"] != "system"]
+    # always include system
+    system_msg = [msg for msg in raw_history if msg["role"] == "system"][0]
+    convo = [m for m in raw_history if m["role"] != "system"]
 
+    # iterative trim
     while True:
-        prompt_parts = [render(system)] + [render(m) for m in convo] + ["AI:"]
-        tokens = len(tokenizer.encode("\n".join(prompt_parts), add_special_tokens=False))
-        if tokens <= CONTEXT_TOKENS or len(convo) <= 2:
+        prompt_parts = [system_msg["content"]] + [render(m) for m in convo] + ["AI:"]
+        token_len = len(tokenizer.encode("\n".join(prompt_parts), add_special_tokens=False))
+        if token_len <= CONTEXT_TOKENS or len(convo) <= 2:
             break
-        # remove oldest user+assistant pair
+        # drop oldest user+assistant pair
         convo = convo[2:]
 
     return "\n".join(prompt_parts)
 
 
 # ---------------------------------------------------------------------------
-# 4. Chat callback
+# 4. Chat callback
 # ---------------------------------------------------------------------------
-def chat_fn(user_text, chat_ui, state):
+def chat_fn(user_msg: str, display_history: list, state: dict):
     """
-    chat_ui : list[(user, assistant)] displayed by gr.Chatbot
-    state : dict{ 'history_raw': [...] } used for prompt construction
+    display_history : list[tuple[str,str]] for UI
+    state["raw"] : list[dict] for prompting
     """
-    user_text = strip(user_text or "")
-    if not user_text:
-        return chat_ui, state
+    user_msg = strip(user_msg or "")
+    if not user_msg:
+        return display_history, state
 
-    if len(user_text) > MAX_INPUT_CH:
-        chat_ui.append((user_text, f"Input exceeds {MAX_INPUT_CH} characters."))
-        return chat_ui, state
+    if len(user_msg) > MAX_INPUT_CH:
+        display_history.append((user_msg, f"Input >{MAX_INPUT_CH} chars."))
+        return display_history, state
 
     if MODEL_ERR:
-        chat_ui.append((user_text, MODEL_ERR))
-        return chat_ui, state
+        display_history.append((user_msg, MODEL_ERR))
+        return display_history, state
 
-    # update raw history (system already present)
-    state["history_raw"].append({"role": "user", "content": user_text})
+    # --- Update raw history
+    state["raw"].append({"role": "user", "content": user_msg})
 
-    prompt = build_prompt(state["history_raw"])
+    # --- Build prompt within token budget
+    prompt = build_prompt(state["raw"])
 
+    # --- Generate
     try:
         start = time.time()
-        out = generator(prompt)[0]["generated_text"]
-        reply = strip(out.split("AI:", 1)[-1])
-        reply = re.split(r"\b(?:User:|AI:)", reply, 1)[0].strip()
-        log(f"Generated in {time.time()-start:.2f}s ({len(reply)} chars)")
+        result = generator(prompt)[0]
+        reply = strip(result["generated_text"])
+        log(f"Reply in {time.time() - start:.2f}s ({len(reply)} chars)")
     except Exception:
-        log("⚠️ Inference error\n" + traceback.format_exc())
-        reply = "Sorry  an internal error occurred. Please try again."
+        log("Inference error:\n" + traceback.format_exc())
+        reply = "Apologies—an internal error occurred. Please try again."
 
-    # append to histories
-    chat_ui.append((user_text, reply))
-    state["history_raw"].append({"role": "assistant", "content": reply})
-    return chat_ui, state
+    # --- Append assistant reply to both histories
+    display_history.append((user_msg, reply))
+    state["raw"].append({"role": "assistant", "content": reply})
+    return display_history, state
 
 
 # ---------------------------------------------------------------------------
-# 5. Gradio UI
+# 5. Launch Gradio Blocks UI
 # ---------------------------------------------------------------------------
 with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
     gr.Markdown("### SchoolSpirit AI Chat")
@@ -170,19 +175,24 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
     )
 
     state = gr.State(
-        {"history_raw": [{"role": "system", "content": SYSTEM_MSG}]}
+        {
+            "raw": [
+                {"role": "system", "content": SYSTEM_MSG},
+                {"role": "assistant", "content": WELCOME_MSG},
+            ]
+        }
     )
 
     with gr.Row():
         txt = gr.Textbox(
             placeholder="Type your question here…",
             show_label=False,
-            lines=1,
             scale=4,
+            lines=1,
        )
-        send = gr.Button("Send", variant="primary")
+        send_btn = gr.Button("Send", variant="primary")
 
-    send.click(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])
+    send_btn.click(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])
     txt.submit(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])
 
 demo.launch()
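
Note on the `return_full_text=False` change: with that flag, a transformers text-generation pipeline returns only the newly generated continuation rather than prompt plus continuation, which is why this commit can drop the old `out.split("AI:", 1)` / `re.split` post-processing. A minimal sketch of the behavior, assuming a tiny stand-in model (sshleifer/tiny-gpt2, chosen only so the snippet runs quickly; not the Granite model used above):

# Sketch: return_full_text=False yields only the continuation,
# so no "AI:"-splitting of the output is needed.
from transformers import pipeline

gen = pipeline(
    "text-generation",
    model="sshleifer/tiny-gpt2",  # stand-in model (assumption), not granite-3.3-2b
    max_new_tokens=8,
    do_sample=False,
    return_full_text=False,
)

prompt = "User: Hi there!\nAI:"
out = gen(prompt)[0]["generated_text"]
print(repr(out))  # continuation only; the prompt text is not echoed back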
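
The trimming loop in `build_prompt` can also be checked in isolation. Below is a self-contained sketch of the same strategy, where a whitespace splitter stands in for `tokenizer.encode` and the 9-token budget is arbitrary (both are assumptions, so the example runs without downloading anything):

# Sketch of the token-budget trim: drop the oldest user/assistant pair
# until the rendered prompt fits, but always keep the system message
# and the latest exchange.
CONTEXT_TOKENS = 9  # arbitrary toy budget

def count_tokens(text: str) -> int:
    return len(text.split())  # stand-in for len(tokenizer.encode(...))

def build_prompt(raw_history):
    system = next(m for m in raw_history if m["role"] == "system")
    convo = [m for m in raw_history if m["role"] != "system"]
    while True:
        parts = [system["content"]] + [
            ("User: " if m["role"] == "user" else "AI: ") + m["content"]
            for m in convo
        ] + ["AI:"]
        if count_tokens("\n".join(parts)) <= CONTEXT_TOKENS or len(convo) <= 2:
            break
        convo = convo[2:]  # oldest user+assistant pair
    return "\n".join(parts)

history = [
    {"role": "system", "content": "Be brief."},
    {"role": "user", "content": "one"},
    {"role": "assistant", "content": "two"},
    {"role": "user", "content": "three"},
    {"role": "assistant", "content": "four"},
]
print(build_prompt(history))  # the oldest pair ("one"/"two") gets trimmed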