phanerozoic committed
Commit 502a6b6 · verified
1 Parent(s): ef0a942

Update app.py

Files changed (1)
  1. app.py +106 -240
app.py CHANGED
@@ -1,285 +1,151 @@
- # ────────────────────────────────────────────────────────────────────────────
- # SchoolSpirit AI Chat – robust edition
- # ────────────────────────────────────────────────────────────────────────────
- # • FP‑16 GPU load → CPU float32 fallback
- # • Streaming responses with retry
- # • Token‑aware context trimming (keeps within model window)
- # • One‑time system + welcome message (no duplication)
- # • Extensive logging
- # ────────────────────────────────────────────────────────────────────────────
- from __future__ import annotations
-
- import asyncio
- import datetime as _dt
- import os
- import re
- import time
- import traceback
- from dataclasses import dataclass
- from pathlib import Path
- from typing import Any, Dict, List, Tuple
-
  import gradio as gr
- import torch
- from transformers import (
-     AutoModelForCausalLM,
-     AutoTokenizer,
-     GenerationConfig,
-     TextIteratorStreamer,
-     pipeline,
- )
  from transformers.utils import logging as hf_logging

- # ────────────────────────────────────────────────────────────────────────────
- # 0. ENV / LOGGING
- # ────────────────────────────────────────────────────────────────────────────
  os.environ["HF_HOME"] = "/data/.huggingface"
- LOG_FILE = Path("/data/requests.log")
- LOG_FILE.parent.mkdir(parents=True, exist_ok=True)


- def log(line: str) -> None:
-     ts = _dt.datetime.utcnow().strftime("%H:%M:%S.%f")[:-3]
-     entry = f"[{ts}] {line}"
-     print(entry, flush=True)
      try:
-         with LOG_FILE.open("a") as f:
-             f.write(entry + "\n")
-     except Exception:
          pass


- hf_logging.set_verbosity_error()
-
-
- # ────────────────────────────────────────────────────────────────────────────
- # 1. CONFIG
- # ────────────────────────────────────────────────────────────────────────────
- @dataclass
- class Config:
-     MODEL_ID: str = "ibm-granite/granite-3.3-2b-instruct"
-     MAX_MODEL_TOKENS: int = 2048
-     MAX_NEW_TOKENS: int = 64
-     TEMPERATURE: float = 0.6
-     TOP_P: float = 0.9
-     MAX_INPUT_CH: int = 300
-     CONTEXT_MARGIN: int = 128  # leave room for assistant completion
-     STREAMING_CHUNK: float = 0.05  # seconds
-     SYSTEM_PROMPT: str = (
-         "You are **SchoolSpirit AI**, the digital mascot for SchoolSpirit AI LLC, "
-         "founded by Charles Norton in 2025. The company installs on‑prem AI chat "
-         "mascots, offers custom fine‑tuning, and ships turnkey GPU hardware to "
-         "K‑12 schools.\n\n"
-         "GUIDELINES:\n"
-         "• Warm, encouraging tone for students, parents, staff.\n"
-         "• Replies ≤ 4 sentences unless asked for detail.\n"
-         "• If unsure/out‑of‑scope: say so and suggest human follow‑up.\n"
-         "• No personal‑data collection or sensitive advice.\n"
-         "• No profanity, politics, or mature themes."
-     )
-     WELCOME_MSG: str = "Welcome to SchoolSpirit AI! Do you have any questions?"


- CFG = Config()

- # ────────────────────────────────────────────────────────────────────────────
- # 2. LOAD MODEL (GPU FP‑16 → CPU fallback)
- # ────────────────────────────────────────────────────────────────────────────
- def load_pipeline() -> pipeline:
      log("Loading tokenizer …")
-     tok = AutoTokenizer.from_pretrained(CFG.MODEL_ID)
-
-     use_gpu = torch.cuda.is_available()
-     dtype = torch.float16 if use_gpu else torch.float32
-     log(f"{'GPU' if use_gpu else 'CPU'} detected → dtype {dtype}")
-
-     model = AutoModelForCausalLM.from_pretrained(
-         CFG.MODEL_ID,
-         device_map="auto" if use_gpu else "cpu",
-         torch_dtype=dtype,
-         low_cpu_mem_usage=not use_gpu,
-     )

-     gen_cfg = GenerationConfig(
-         max_new_tokens=CFG.MAX_NEW_TOKENS,
-         temperature=CFG.TEMPERATURE,
-         top_p=CFG.TOP_P,
-     )

-     pipe = pipeline(
          "text-generation",
          model=model,
          tokenizer=tok,
-         generation_config=gen_cfg,
      )
-     pipe.tokenizer.padding_side = "left"
-     pipe.tokenizer.pad_token_id = pipe.model.config.eos_token_id
-     log("Model & pipeline loaded ✔")
-     return pipe
-
-
- try:
-     PIPE = load_pipeline()
      MODEL_ERR = None
  except Exception as exc:  # noqa: BLE001
-     MODEL_ERR = str(exc)
-     log(f"Model load error: {exc}")
-
- # ────────────────────────────────────────────────────────────────────────────
- # 3. HELPER FUNCTIONS
- # ────────────────────────────────────────────────────────────────────────────
- _tokenizer = PIPE.tokenizer if PIPE else None
- _strip = lambda s: re.sub(r"\s+", " ", s.strip())
-
-
- def build_prompt(raw: List[Dict[str, str]]) -> str:
-     """
-     raw: [{'role':'system'|'user'|'assistant', 'content': str}, ...]
-     """
-     lines: List[str] = []
-     for m in raw:
-         if m["role"] == "system":
-             lines.append(m["content"])
-         else:
-             prefix = "User" if m["role"] == "user" else "AI"
-             lines.append(f"{prefix}: {m['content']}")
-     lines.append("AI:")
-     return "\n".join(lines)
-
-
- def trim_to_window(raw: List[Dict[str, str]]) -> List[Dict[str, str]]:
-     """
-     Trim raw history so total tokens <= model window - margin.
-     Always keep the initial system message.
-     """
-     if not PIPE:
-         return raw
-     max_total = CFG.MAX_MODEL_TOKENS - CFG.CONTEXT_MARGIN
-     while True:
-         toks = len(_tokenizer.encode(build_prompt(raw)))
-         if toks <= max_total or len(raw) <= 2:
-             return raw
-         # Remove second message (first non‑system) then loop
-         raw.pop(1)
-
-
- # ────────────────────────────────────────────────────────────────────────────
- # 4. CHAT HANDLER
- # ────────────────────────────────────────────────────────────────────────────
- async def generate_stream(prompt: str):
-     """
-     Yields partial text chunks for streaming.
-     """
-     streamer = TextIteratorStreamer(
-         PIPE.tokenizer, skip_prompt=True, skip_special_tokens=True
-     )
-     gen_kwargs = dict(prompt, streamer=streamer)
-     loop = asyncio.get_event_loop()
-     task = loop.run_in_executor(None, PIPE.model.generate, **gen_kwargs)
-
-     # Stream chunks
-     async for token in streamer:
-         yield token
-         await asyncio.sleep(CFG.STREAMING_CHUNK)
-     await task  # ensure generation done


- def respond(
-     user_msg: str, chat_hist: List[Tuple[str, str]], state: Dict[str, Any]
- ) -> Tuple[List[Tuple[str, str]], Dict[str, Any]]:
      """
-     Gradio synchronous wrapper that kicks off async generation.
      """
      if MODEL_ERR:
-         chat_hist.append((user_msg, MODEL_ERR))
-         return chat_hist, state

-     user_msg = _strip(user_msg or "")
      if not user_msg:
-         chat_hist.append((user_msg, "Please type something."))
-         return chat_hist, state
-     if len(user_msg) > CFG.MAX_INPUT_CH:
-         chat_hist.append(
-             (user_msg, f"Message too long (>{CFG.MAX_INPUT_CH} chars).")
-         )
-         return chat_hist, state
-
-     raw = state["raw"]
-     raw.append({"role": "user", "content": user_msg})
-     raw = trim_to_window(raw)
-     prompt = build_prompt(raw)
-
-     # Streaming generation
-     streamer = TextIteratorStreamer(
-         PIPE.tokenizer, skip_prompt=True, skip_special_tokens=True
-     )
-     gen_task = PIPE.model.generate(
-         PIPE.tokenizer(prompt, return_tensors="pt").to(PIPE.model.device)["input_ids"],
-         streamer=streamer,
-         max_new_tokens=CFG.MAX_NEW_TOKENS,
-         temperature=CFG.TEMPERATURE,
-         top_p=CFG.TOP_P,
      )

-     reply = ""
-     for token in streamer:
-         reply += token
-         chat_hist[-1] = (user_msg, reply)
-         yield chat_hist, state

-     raw.append({"role": "assistant", "content": reply})
-     state["raw"] = raw
-     yield chat_hist, state


- # ────────────────────────────────────────────────────────────────────────────
- # 5. LAUNCH UI (Gradio Blocks)
- # ────────────────────────────────────────────────────────────────────────────
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
-     gr.Markdown("# 🏫 SchoolSpirit AI Chat")
-
      chatbot = gr.Chatbot(
-         value=[("", CFG.WELCOME_MSG)],
-         height=480,
-         label="SchoolSpirit AI",
-     )
-
-     state = gr.State(
-         {"raw": [{"role": "system", "content": CFG.SYSTEM_PROMPT}]}
      )
-
      with gr.Row():
          txt = gr.Textbox(
-             scale=4,
-             placeholder="Ask me anything about SchoolSpirit AI …",
-             show_label=False,
          )
          send = gr.Button("Send", variant="primary")

-     # Bind both button click and ENTER keypress
-     for trigger in (send, txt):
-         trigger.click(
-             respond,
-             inputs=[txt, chatbot, state],
-             outputs=[chatbot, state],
-         ).then(
-             lambda: "",
-             None,
-             txt,
-         )  # clear textbox
-
-     demo.load(lambda: None)  # dummy to ensure Blocks builds
-
- # ---------------------------------------------------------------------------
- # Graceful shutdown (for HF Space restarts)
- # ---------------------------------------------------------------------------
- def _shutdown(*_):
-     log("Space shutting down …")
-
-
- import atexit, signal  # noqa: E402
-
- atexit.register(_shutdown)
- signal.signal(signal.SIGTERM, lambda *_: _shutdown())
- signal.signal(signal.SIGINT, lambda *_: _shutdown())

  demo.launch()

+ import os, re, time, datetime, traceback, torch
  import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
  from transformers.utils import logging as hf_logging

+ # -------------------------------------------------------------------
+ # 1. Logging helpers
+ # -------------------------------------------------------------------
  os.environ["HF_HOME"] = "/data/.huggingface"
+ LOG_FILE = "/data/requests.log"


+ def log(msg: str):
+     ts = datetime.datetime.utcnow().strftime("%H:%M:%S.%f")[:-3]
+     line = f"[{ts}] {msg}"
+     print(line, flush=True)
      try:
+         with open(LOG_FILE, "a") as f:
+             f.write(line + "\n")
+     except FileNotFoundError:
          pass

+ # -------------------------------------------------------------------
+ # 2. Configuration
+ # -------------------------------------------------------------------
+ MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
+ MAX_TURNS, MAX_TOKENS, MAX_INPUT_CH = 4, 64, 300
+
+ SYSTEM_MSG = (
+     "You are **SchoolSpirit AI**, the digital mascot for SchoolSpirit AI LLC, "
+     "founded by Charles Norton in 2025. The company installs on‑prem AI chat "
+     "mascots, offers custom fine‑tuning, and ships turnkey GPU hardware to "
+     "K‑12 schools.\n\n"
+     "GUIDELINES:\n"
+     "• Warm, encouraging tone for students, parents, staff.\n"
+     "• Replies ≤ 4 sentences unless asked for detail.\n"
+     "• If unsure/out‑of‑scope: say so and suggest human follow‑up.\n"
+     "• No personal‑data collection or sensitive advice.\n"
+     "• No profanity, politics, or mature themes."
+ )
+ WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"


+ def strip(s: str) -> str:
+     return re.sub(r"\s+", " ", s.strip())

+ # -------------------------------------------------------------------
+ # 3. Load model (GPU FP‑16 → CPU fallback)
+ # -------------------------------------------------------------------
+ hf_logging.set_verbosity_error()
+ try:
      log("Loading tokenizer …")
+     tok = AutoTokenizer.from_pretrained(MODEL_ID)

+     if torch.cuda.is_available():
+         log("GPU detected → FP‑16")
+         model = AutoModelForCausalLM.from_pretrained(
+             MODEL_ID, device_map="auto", torch_dtype=torch.float16
+         )
+     else:
+         log("CPU fallback")
+         model = AutoModelForCausalLM.from_pretrained(
+             MODEL_ID, device_map="cpu", torch_dtype="auto", low_cpu_mem_usage=True
+         )

+     gen = pipeline(
          "text-generation",
          model=model,
          tokenizer=tok,
+         max_new_tokens=MAX_TOKENS,
+         do_sample=True,
+         temperature=0.6,
      )
      MODEL_ERR = None
+     log("Model loaded ✔")
  except Exception as exc:  # noqa: BLE001
+     MODEL_ERR, gen = f"Model load error: {exc}", None
+     log(MODEL_ERR)


+ # -------------------------------------------------------------------
+ # 4. Chat callback
+ # -------------------------------------------------------------------
+ def chat_fn(user_msg: str, history: list[tuple[str, str]], state: dict):
      """
+     history: list of (user, assistant) tuples (Gradio default)
+     state  : dict carrying system_prompt + raw_history for the model
+     Returns updated history (for UI) and state (for next round)
      """
      if MODEL_ERR:
+         return history + [(user_msg, MODEL_ERR)], state

+     user_msg = strip(user_msg or "")
      if not user_msg:
+         return history + [(user_msg, "Please type something.")], state
+     if len(user_msg) > MAX_INPUT_CH:
+         warn = f"Message too long (>{MAX_INPUT_CH} chars)."
+         return history + [(user_msg, warn)], state
+
+     # ------------------------------------------------ Prompt assembly
+     raw_hist = state.get("raw", [])
+     raw_hist.append({"role": "user", "content": user_msg})
+     # keep system + last N exchanges
+     convo = [m for m in raw_hist if m["role"] != "system"][-MAX_TURNS * 2 :]
+     raw_hist = [{"role": "system", "content": SYSTEM_MSG}] + convo
+
+     prompt = "\n".join(
+         [
+             m["content"]
+             if m["role"] == "system"
+             else f'{"User" if m["role"]=="user" else "AI"}: {m["content"]}'
+             for m in raw_hist
+         ]
+         + ["AI:"]
      )

+     try:
+         raw = gen(prompt)[0]["generated_text"]
+         reply = strip(raw.split("AI:", 1)[-1])
+         reply = re.split(r"\b(?:User:|AI:)", reply, 1)[0].strip()
+     except Exception:
+         log("❌ Inference error:\n" + traceback.format_exc())
+         reply = "Sorry—backend crashed. Please try again later."

+     # ------------------------------------------------ Update state + UI history
+     raw_hist.append({"role": "assistant", "content": reply})
+     state["raw"] = raw_hist
+     history.append((user_msg, reply))
+     return history, state


+ # -------------------------------------------------------------------
+ # 5. Launch
+ # -------------------------------------------------------------------
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
      chatbot = gr.Chatbot(
+         value=[("", WELCOME_MSG)], height=480, label="SchoolSpirit AI"
      )
+     state = gr.State({"raw": [{"role": "system", "content": SYSTEM_MSG}]})
      with gr.Row():
          txt = gr.Textbox(
+             scale=4, placeholder="Type your question here...", show_label=False
          )
          send = gr.Button("Send", variant="primary")

+     send.click(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])
+     txt.submit(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])

  demo.launch()
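
Note on the new history policy: the commit drops the old token-counting trim_to_window() and instead keeps the system prompt plus only the last MAX_TURNS user/assistant exchanges before flattening them into a "User:/AI:" prompt. The snippet below is a minimal standalone sketch of that step for reference; the constant values and the dummy history are illustrative only and are not part of the commit.

# Standalone sketch of the trimming + prompt assembly used inside chat_fn.
# MAX_TURNS and SYSTEM_MSG are placeholder values for illustration.
MAX_TURNS = 4
SYSTEM_MSG = "You are SchoolSpirit AI."

# Fake conversation: system message followed by ten alternating turns.
raw_hist = [{"role": "system", "content": SYSTEM_MSG}] + [
    {"role": "user" if i % 2 == 0 else "assistant", "content": f"msg {i}"}
    for i in range(10)
]

# Keep the system prompt plus the last MAX_TURNS exchanges (two messages each).
convo = [m for m in raw_hist if m["role"] != "system"][-MAX_TURNS * 2:]
raw_hist = [{"role": "system", "content": SYSTEM_MSG}] + convo

# Flatten into the plain-text prompt the pipeline is asked to continue.
prompt = "\n".join(
    [
        m["content"]
        if m["role"] == "system"
        else f'{"User" if m["role"] == "user" else "AI"}: {m["content"]}'
        for m in raw_hist
    ]
    + ["AI:"]
)

print(len(raw_hist))            # 9: the system message plus the eight kept messages
print(prompt.splitlines()[-1])  # "AI:" is the completion cue the model continues from

Running the sketch shows the window stays bounded regardless of conversation length, which is the simpler replacement for the old token-budget trimming.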