phanerozoic committed on
Commit ef0a942 · verified · 1 Parent(s): d672735

Update app.py

Files changed (1)
  app.py  +247 -100
app.py CHANGED
@@ -1,138 +1,285 @@
- import os, re, time, datetime, traceback, torch
  import gradio as gr
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
  from transformers.utils import logging as hf_logging

- # ---------- Logging ---------------------------------------------------------
  os.environ["HF_HOME"] = "/data/.huggingface"
- LOG_FILE = "/data/requests.log"


- def log(msg: str):
-     ts = datetime.datetime.utcnow().strftime("%H:%M:%S.%f")[:-3]
-     line = f"[{ts}] {msg}"
-     print(line, flush=True)
      try:
-         with open(LOG_FILE, "a") as f:
-             f.write(line + "\n")
-     except FileNotFoundError:
          pass


- # ---------- Config ----------------------------------------------------------
- MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
- MAX_TURNS, MAX_TOKENS, MAX_INPUT_CH = 4, 64, 300
-
- SYSTEM_MSG = (
-     "You are **SchoolSpirit AI**, the digital mascot for SchoolSpirit AI LLC, "
-     "founded by Charles Norton in 2025. The company installs on‑prem AI chat "
-     "mascots, offers custom fine‑tuning of language models, and ships turnkey "
-     "GPU hardware to K‑12 schools.\n\n"
-     "GUIDELINES:\n"
-     "• Warm, encouraging tone for students, parents, staff.\n"
-     "• Replies ≤ 4 sentences unless asked for detail.\n"
-     "• If unsure/out‑of‑scope: say so & suggest human follow‑up.\n"
-     "• No personal‑data collection or sensitive advice.\n"
-     "• No profanity, politics, or mature themes."
- )
- WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"
-
- # ---------- Model load (GPU FP‑16 → CPU fallback) ---------------------------
  hf_logging.set_verbosity_error()
- try:
      log("Loading tokenizer …")
-     tok = AutoTokenizer.from_pretrained(MODEL_ID)

-     if torch.cuda.is_available():
-         log("GPU detected → FP‑16")
-         model = AutoModelForCausalLM.from_pretrained(
-             MODEL_ID, device_map="auto", torch_dtype=torch.float16
-         )
-     else:
-         log("CPU fallback")
-         model = AutoModelForCausalLM.from_pretrained(
-             MODEL_ID, device_map="cpu", torch_dtype="auto", low_cpu_mem_usage=True
-         )

-     gen = pipeline(
          "text-generation",
          model=model,
          tokenizer=tok,
-         max_new_tokens=MAX_TOKENS,
-         do_sample=True,
-         temperature=0.6,
      )
      MODEL_ERR = None
-     log("Model loaded ✔")
  except Exception as exc:  # noqa: BLE001
-     MODEL_ERR, gen = f"Model load error: {exc}", None
-     log(MODEL_ERR)

- clean = lambda t: re.sub(r"\s+", " ", t.strip()) or "…"


- def trim(hist: list):
-     """keep system + last N user/AI pairs"""
-     sys = [m for m in hist if m["role"] == "system"]
-     convo = [m for m in hist if m["role"] != "system"]
-     return sys + convo[-MAX_TURNS * 2 :]


- # ---------- Chat callback ---------------------------------------------------
- def chat_fn(user_msg: str, history: list):
      """
-     history: list[dict] like [{'role':'assistant','content':...}, ...]
-     Return -> reply_str (Gradio appends it as assistant msg)
      """
-     log(f"User sent {len(user_msg)} chars")

-     # Ensure system message present exactly once
-     if not any(m["role"] == "system" for m in history):
-         history.insert(0, {"role": "system", "content": SYSTEM_MSG})

      if MODEL_ERR:
-         return MODEL_ERR

-     user_msg = clean(user_msg or "")
      if not user_msg:
-         return "Please type something."
-     if len(user_msg) > MAX_INPUT_CH:
-         return f"Message too long (>{MAX_INPUT_CH} chars)."
-
-     history.append({"role": "user", "content": user_msg})
-     history = trim(history)
-
-     prompt = "\n".join(
-         [
-             m["content"]
-             if m["role"] == "system"
-             else f'{"User" if m["role"]=="user" else "AI"}: {m["content"]}'
-             for m in history
-         ]
-         + ["AI:"]
      )

-     try:
-         raw = gen(prompt)[0]["generated_text"]
-         reply = clean(raw.split("AI:", 1)[-1])
-         reply = re.split(r"\b(?:User:|AI:)", reply, 1)[0].strip()
-         log(f"Reply {len(reply)} chars")
-     except Exception:
-         log("❌ Inference exception:\n" + traceback.format_exc())
-         reply = "Sorry—backend crashed. Please try again later."

-     return reply


- # ---------- UI --------------------------------------------------------------
- gr.ChatInterface(
-     fn=chat_fn,
-     chatbot=gr.Chatbot(
          height=480,
-         type="messages",
-         value=[{"role": "assistant", "content": WELCOME_MSG}],  # one‑time welcome
-     ),
-     title="SchoolSpirit AI Chat",
-     theme=gr.themes.Soft(primary_hue="blue"),
-     type="messages",
- ).launch()

+ # ────────────────────────────────────────────────────────────────────────────
+ # SchoolSpirit AI Chat – robust edition
+ # ────────────────────────────────────────────────────────────────────────────
+ # • FP‑16 GPU load → CPU float32 fallback
+ # • Streaming responses with retry
+ # • Token‑aware context trimming (keeps within model window)
+ # • One‑time system + welcome message (no duplication)
+ # • Extensive logging
+ # ────────────────────────────────────────────────────────────────────────────
+ from __future__ import annotations
+
+ import asyncio
+ import datetime as _dt
+ import os
+ import re
+ import threading  # runs generate() in the background for streaming
+ import time
+ import traceback
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any, Dict, List, Tuple
+
  import gradio as gr
+ import torch
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     GenerationConfig,
+     TextIteratorStreamer,
+     pipeline,
+ )
  from transformers.utils import logging as hf_logging

+ # ────────────────────────────────────────────────────────────────────────────
+ # 0. ENV / LOGGING
+ # ────────────────────────────────────────────────────────────────────────────
  os.environ["HF_HOME"] = "/data/.huggingface"
+ LOG_FILE = Path("/data/requests.log")
+ LOG_FILE.parent.mkdir(parents=True, exist_ok=True)


+ def log(line: str) -> None:
+     ts = _dt.datetime.utcnow().strftime("%H:%M:%S.%f")[:-3]
+     entry = f"[{ts}] {line}"
+     print(entry, flush=True)
      try:
+         with LOG_FILE.open("a") as f:
+             f.write(entry + "\n")
+     except Exception:
          pass


  hf_logging.set_verbosity_error()
+
+ # ────────────────────────────────────────────────────────────────────────────
+ # 1. CONFIG
+ # ────────────────────────────────────────────────────────────────────────────
+ @dataclass
+ class Config:
+     MODEL_ID: str = "ibm-granite/granite-3.3-2b-instruct"
+     MAX_MODEL_TOKENS: int = 2048
+     MAX_NEW_TOKENS: int = 64
+     TEMPERATURE: float = 0.6
+     TOP_P: float = 0.9
+     MAX_INPUT_CH: int = 300
+     CONTEXT_MARGIN: int = 128  # leave room for the assistant completion
+     STREAMING_CHUNK: float = 0.05  # seconds
+     SYSTEM_PROMPT: str = (
+         "You are **SchoolSpirit AI**, the digital mascot for SchoolSpirit AI LLC, "
+         "founded by Charles Norton in 2025. The company installs on‑prem AI chat "
+         "mascots, offers custom fine‑tuning, and ships turnkey GPU hardware to "
+         "K‑12 schools.\n\n"
+         "GUIDELINES:\n"
+         "• Warm, encouraging tone for students, parents, staff.\n"
+         "• Replies ≤ 4 sentences unless asked for detail.\n"
+         "• If unsure/out‑of‑scope: say so and suggest human follow‑up.\n"
+         "• No personal‑data collection or sensitive advice.\n"
+         "• No profanity, politics, or mature themes."
+     )
+     WELCOME_MSG: str = "Welcome to SchoolSpirit AI! Do you have any questions?"
+
+
+ CFG = Config()
+
+ # ────────────────────────────────────────────────────────────────────────────
+ # 2. LOAD MODEL (GPU FP‑16 → CPU fallback)
+ # ────────────────────────────────────────────────────────────────────────────
+ def load_pipeline() -> pipeline:
      log("Loading tokenizer …")
+     tok = AutoTokenizer.from_pretrained(CFG.MODEL_ID)

+     use_gpu = torch.cuda.is_available()
+     dtype = torch.float16 if use_gpu else torch.float32
+     log(f"{'GPU' if use_gpu else 'CPU'} detected → dtype {dtype}")
+
+     model = AutoModelForCausalLM.from_pretrained(
+         CFG.MODEL_ID,
+         device_map="auto" if use_gpu else "cpu",
+         torch_dtype=dtype,
+         low_cpu_mem_usage=not use_gpu,
+     )

+     gen_cfg = GenerationConfig(
+         max_new_tokens=CFG.MAX_NEW_TOKENS,
+         do_sample=True,  # temperature / top_p only take effect when sampling
+         temperature=CFG.TEMPERATURE,
+         top_p=CFG.TOP_P,
+     )
+
+     pipe = pipeline(
          "text-generation",
          model=model,
          tokenizer=tok,
+         generation_config=gen_cfg,
      )
+     pipe.tokenizer.padding_side = "left"
+     pipe.tokenizer.pad_token_id = pipe.model.config.eos_token_id
+     log("Model & pipeline loaded ✔")
+     return pipe
+
+
+ try:
+     PIPE = load_pipeline()
      MODEL_ERR = None
  except Exception as exc:  # noqa: BLE001
+     PIPE = None  # keep the name defined so the helpers below don't NameError
+     MODEL_ERR = str(exc)
+     log(f"Model load error: {exc}")
+
+ # ────────────────────────────────────────────────────────────────────────────
+ # 3. HELPER FUNCTIONS
+ # ────────────────────────────────────────────────────────────────────────────
+ _tokenizer = PIPE.tokenizer if PIPE else None
+ _strip = lambda s: re.sub(r"\s+", " ", s.strip())


+ def build_prompt(raw: List[Dict[str, str]]) -> str:
+     """
+     raw: [{'role':'system'|'user'|'assistant', 'content': str}, ...]
+     """
+     lines: List[str] = []
+     for m in raw:
+         if m["role"] == "system":
+             lines.append(m["content"])
+         else:
+             prefix = "User" if m["role"] == "user" else "AI"
+             lines.append(f"{prefix}: {m['content']}")
+     lines.append("AI:")
+     return "\n".join(lines)


+ def trim_to_window(raw: List[Dict[str, str]]) -> List[Dict[str, str]]:
+     """
+     Trim raw history so total tokens <= model window - margin.
+     Always keep the initial system message.
+     """
+     if not PIPE:
+         return raw
+     max_total = CFG.MAX_MODEL_TOKENS - CFG.CONTEXT_MARGIN
+     while True:
+         toks = len(_tokenizer.encode(build_prompt(raw)))
+         if toks <= max_total or len(raw) <= 2:
+             return raw
+         # Remove second message (first non‑system) then loop
+         raw.pop(1)


+ # ────────────────────────────────────────────────────────────────────────────
+ # 4. CHAT HANDLER
+ # ────────────────────────────────────────────────────────────────────────────
+ async def generate_stream(prompt: str):
      """
+     Yields partial text chunks for streaming.
      """
+     streamer = TextIteratorStreamer(
+         PIPE.tokenizer, skip_prompt=True, skip_special_tokens=True
+     )
+     inputs = PIPE.tokenizer(prompt, return_tensors="pt").to(PIPE.model.device)
+     loop = asyncio.get_event_loop()
+     task = loop.run_in_executor(
+         None, lambda: PIPE.model.generate(**inputs, streamer=streamer)
+     )
+
+     # Stream chunks (TextIteratorStreamer is a plain iterator, not an async one)
+     for token in streamer:
+         yield token
+         await asyncio.sleep(CFG.STREAMING_CHUNK)
+     await task  # ensure generation done

+ def respond(
+     user_msg: str, chat_hist: List[Tuple[str, str]], state: Dict[str, Any]
+ ):
+     """
+     Gradio handler: a generator that streams partial replies into the Chatbot.
+     """
      if MODEL_ERR:
+         chat_hist.append((user_msg, MODEL_ERR))
+         yield chat_hist, state
+         return

+     user_msg = _strip(user_msg or "")
      if not user_msg:
+         chat_hist.append((user_msg, "Please type something."))
+         yield chat_hist, state
+         return
+     if len(user_msg) > CFG.MAX_INPUT_CH:
+         chat_hist.append(
+             (user_msg, f"Message too long (>{CFG.MAX_INPUT_CH} chars).")
+         )
+         yield chat_hist, state
+         return
+
+     raw = state["raw"]
+     raw.append({"role": "user", "content": user_msg})
+     raw = trim_to_window(raw)
+     prompt = build_prompt(raw)
+
+     # Streaming generation: run generate() in a worker thread, drain the streamer here
+     streamer = TextIteratorStreamer(
+         PIPE.tokenizer, skip_prompt=True, skip_special_tokens=True
+     )
+     gen_thread = threading.Thread(
+         target=PIPE.model.generate,
+         args=(PIPE.tokenizer(prompt, return_tensors="pt").to(PIPE.model.device)["input_ids"],),
+         kwargs=dict(
+             streamer=streamer,
+             max_new_tokens=CFG.MAX_NEW_TOKENS,
+             do_sample=True,
+             temperature=CFG.TEMPERATURE,
+             top_p=CFG.TOP_P,
+         ),
+         daemon=True,
      )
+     gen_thread.start()

+     chat_hist.append((user_msg, ""))  # placeholder row that the stream fills in
+     reply = ""
+     for token in streamer:
+         reply += token
+         chat_hist[-1] = (user_msg, reply)
+         yield chat_hist, state
+
+     raw.append({"role": "assistant", "content": reply})
+     state["raw"] = raw
+     yield chat_hist, state


+ # ────────────────────────────────────────────────────────────────────────────
+ # 5. LAUNCH UI (Gradio Blocks)
+ # ────────────────────────────────────────────────────────────────────────────
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
+     gr.Markdown("# 🏫 SchoolSpirit AI Chat")

+     chatbot = gr.Chatbot(
+         value=[("", CFG.WELCOME_MSG)],
          height=480,
+         label="SchoolSpirit AI",
+     )
+
+     state = gr.State(
+         {"raw": [{"role": "system", "content": CFG.SYSTEM_PROMPT}]}
+     )
+
+     with gr.Row():
+         txt = gr.Textbox(
+             scale=4,
+             placeholder="Ask me anything about SchoolSpirit AI …",
+             show_label=False,
+         )
+         send = gr.Button("Send", variant="primary")
+
+     # Bind both the Send button and the ENTER keypress in the textbox
+     send.click(
+         respond,
+         inputs=[txt, chatbot, state],
+         outputs=[chatbot, state],
+     ).then(lambda: "", None, txt)  # clear textbox
+     txt.submit(
+         respond,
+         inputs=[txt, chatbot, state],
+         outputs=[chatbot, state],
+     ).then(lambda: "", None, txt)  # clear textbox
+
+     demo.load(lambda: None)  # dummy to ensure Blocks builds
+
+ # ---------------------------------------------------------------------------
+ # Graceful shutdown (for HF Space restarts)
+ # ---------------------------------------------------------------------------
+ def _shutdown(*_):
+     log("Space shutting down …")
+
+
+ import atexit, signal  # noqa: E402
+
+ atexit.register(_shutdown)
+ signal.signal(signal.SIGTERM, lambda *_: _shutdown())
+ signal.signal(signal.SIGINT, lambda *_: _shutdown())
+
+ demo.launch()
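
A note on the token‑aware trimming added above: trim_to_window keeps dropping the oldest non‑system turn until the rendered prompt fits within MAX_MODEL_TOKENS minus CONTEXT_MARGIN. A small self‑contained sketch of the same loop, using a hypothetical whitespace token counter in place of the real tokenizer so it runs without downloading the model:

# Sketch of the trimming loop; count_tokens() is a stand-in for tokenizer.encode().
from typing import Dict, List

MAX_MODEL_TOKENS = 2048
CONTEXT_MARGIN = 128


def count_tokens(text: str) -> int:
    return len(text.split())  # hypothetical stub; the app measures real tokens


def render(raw: List[Dict[str, str]]) -> str:
    # Same layout as build_prompt(): system text, then "User:"/"AI:" turns, then "AI:"
    parts = [m["content"] if m["role"] == "system"
             else f'{"User" if m["role"] == "user" else "AI"}: {m["content"]}'
             for m in raw]
    return "\n".join(parts + ["AI:"])


def trim_to_window(raw: List[Dict[str, str]]) -> List[Dict[str, str]]:
    max_total = MAX_MODEL_TOKENS - CONTEXT_MARGIN
    # Drop the oldest non-system message until the prompt fits,
    # never removing the system message or the latest turn.
    while count_tokens(render(raw)) > max_total and len(raw) > 2:
        raw.pop(1)
    return raw


history = [{"role": "system", "content": "You are SchoolSpirit AI."}]
for _ in range(5):
    history.append({"role": "user", "content": "hello " * 250})
    history.append({"role": "assistant", "content": "hi " * 250})
history.append({"role": "user", "content": "What do you offer?"})

print(len(history), "messages before;", len(trim_to_window(history)), "after")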
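
And a note on the streaming path: respond pairs transformers' TextIteratorStreamer with a generate() call running in a worker thread, then yields the growing reply back to Gradio as the streamer is drained. A minimal standalone sketch of that pattern, assuming only transformers and torch are installed; the tiny sshleifer/tiny-gpt2 checkpoint is just a quick-to-download stand-in, not the Granite model the Space actually serves:

# Minimal TextIteratorStreamer pattern: generate() blocks, so it runs in a
# background thread while the main thread prints chunks as they arrive.
import threading

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_ID = "sshleifer/tiny-gpt2"  # stand-in model so the example runs quickly
tok = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

prompt = "User: Hello!\nAI:"
inputs = tok(prompt, return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

thread = threading.Thread(
    target=model.generate,
    kwargs=dict(**inputs, streamer=streamer, max_new_tokens=20),
    daemon=True,
)
thread.start()

for chunk in streamer:  # each chunk is already-decoded text
    print(chunk, end="", flush=True)
print()

This is the same flow the Space's respond handler uses, just reduced to a script.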