phanerozoic committed on
Commit 56e0226 · verified · 1 Parent(s): 47f8251

Update app.py

Files changed (1)
  1. app.py +62 -90
app.py CHANGED
@@ -14,6 +14,7 @@ from transformers.utils import logging as hf_logging
 os.environ["HF_HOME"] = "/data/.huggingface"
 LOG_FILE = "/data/requests.log"
 
+
 def log(msg: str):
     ts = datetime.datetime.utcnow().strftime("%H:%M:%S.%f")[:-3]
     line = f"[{ts}] {msg}"
@@ -21,8 +22,9 @@ def log(msg: str):
     try:
         with open(LOG_FILE, "a") as f:
             f.write(line + "\n")
-    except Exception:
-        pass  # tolerate logging failures
+    except FileNotFoundError:
+        pass
+
 
 # ---------------------------------------------------------------------------
 # 1. Configuration constants
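Narrowing the handler from a blanket except Exception to except FileNotFoundError means log() now only tolerates a missing /data volume; any other I/O error will propagate out of the logger. An alternative that removes that failure mode up front, as a minimal sketch (it assumes the same LOG_FILE constant and a writable /data mount, as on the Space):

import os

LOG_FILE = "/data/requests.log"
# Creating the parent directory once at startup means the append inside
# log() can no longer raise FileNotFoundError at request time.
os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True)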
@@ -48,6 +50,7 @@ WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"
 
 strip = lambda s: re.sub(r"\s+", " ", s.strip())
 
+
 # ---------------------------------------------------------------------------
 # 2. Load tokenizer + model (GPU FP‑16 → CPU)
 # ---------------------------------------------------------------------------
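The next hunk only reflows the CPU branch of the loader; the GPU path is not shown in this diff. A hedged sketch of what the surrounding try-block presumably does, inferred from the "GPU FP-16 → CPU" section header and the "CPU fallback" log line (MODEL_ID and log are app.py's own names):

import torch
from transformers import AutoModelForCausalLM

if torch.cuda.is_available():
    # GPU present: load the weights in half precision
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, device_map="auto", torch_dtype=torch.float16
    )
else:
    log("CPU fallback")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, device_map="cpu", torch_dtype="auto", low_cpu_mem_usage=True
    )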
@@ -64,7 +67,10 @@ try:
     else:
         log("CPU fallback")
         model = AutoModelForCausalLM.from_pretrained(
-            MODEL_ID, device_map="cpu", torch_dtype="auto", low_cpu_mem_usage=True
+            MODEL_ID,
+            device_map="cpu",
+            torch_dtype="auto",
+            low_cpu_mem_usage=True,
         )
 
     generator = pipeline(
@@ -75,7 +81,6 @@ try:
         do_sample=True,
         temperature=TEMPERATURE,
         return_full_text=False,
-        streaming=True,  # ← enable token-by-token streaming
     )
     MODEL_ERR = None
     log("Model loaded ✔")
@@ -84,121 +89,88 @@ except Exception as exc:
     generator = None
     log(MODEL_ERR)
 
+
 # ---------------------------------------------------------------------------
-# 3. Helper: build prompt under token budget (with fallback on error)
+# 3. Helper: build prompt under token budget
 # ---------------------------------------------------------------------------
 def build_prompt(raw_history: list[dict]) -> str:
-    try:
-        def render(msg):
-            if msg["role"] == "system":
-                return msg["content"]
-            prefix = "User:" if msg["role"] == "user" else "AI:"
-            return f"{prefix} {msg['content']}"
-
-        system_msg = next(m for m in raw_history if m["role"] == "system")
-        convo = [m for m in raw_history if m["role"] != "system"]
-
-        while True:
-            parts = [system_msg["content"]] + [render(m) for m in convo] + ["AI:"]
-            token_len = len(tokenizer.encode("\n".join(parts), add_special_tokens=False))
-            if token_len <= CONTEXT_TOKENS or len(convo) <= 2:
-                break
-            convo = convo[2:]
-        return "\n".join(parts)
+    system_msg = [m for m in raw_history if m["role"] == "system"][0]
+    convo = [m for m in raw_history if m["role"] != "system"]
+
+    def render(msg):
+        prefix = "User:" if msg["role"] == "user" else "AI:"
+        return f"{prefix} {msg['content']}" if msg["role"] != "system" else msg["content"]
+
+    while True:
+        parts = [system_msg["content"]] + [render(m) for m in convo] + ["AI:"]
+        token_len = len(tokenizer.encode("\n".join(parts), add_special_tokens=False))
+        if token_len <= CONTEXT_TOKENS or len(convo) <= 2:
+            break
+        convo = convo[2:]
+    return "\n".join([system_msg["content"]] + [render(m) for m in convo] + ["AI:"])
 
-    except Exception:
-        log("Error building prompt:\n" + traceback.format_exc())
-        # Fallback: include system + last two messages
-        sys_text = next((m["content"] for m in raw_history if m["role"]=="system"), "")
-        tail = [m for m in raw_history if m["role"]!="system"][-2:]
-        fallback = sys_text + "\n" + "\n".join(
-            f"{'User:' if m['role']=='user' else 'AI:'} {m['content']}" for m in tail
-        ) + "\nAI:"
-        return fallback
 
 # ---------------------------------------------------------------------------
-# 4. Chat callback (streaming generator + robust error handling)
+# 4. Chat callback with immediate user echo & spinner
 # ---------------------------------------------------------------------------
 def chat_fn(user_msg: str, display_history: list, state: dict):
     user_msg = strip(user_msg or "")
     if not user_msg:
-        yield display_history, state
-        return
+        return display_history, state
 
     if len(user_msg) > MAX_INPUT_CH:
         display_history.append((user_msg, f"Input >{MAX_INPUT_CH} chars."))
-        yield display_history, state
-        return
+        return display_history, state
 
     if MODEL_ERR:
         display_history.append((user_msg, MODEL_ERR))
-        yield display_history, state
-        return
+        return display_history, state
 
-    try:
-        # record user
-        state["raw"].append({"role": "user", "content": user_msg})
-        display_history.append((user_msg, ""))
+    # Immediately append the user message with a placeholder
+    display_history.append((user_msg, ""))
 
-        prompt = build_prompt(state["raw"])
-        start = time.time()
-        partial = ""
-
-        # stream chunks
-        for chunk in generator(prompt):
-            try:
-                new_text = strip(chunk.get("generated_text", ""))
-                if "User:" in new_text:
-                    new_text = new_text.split("User:", 1)[0].strip()
-                partial += new_text
-                display_history[-1] = (user_msg, partial)
-                yield display_history, state
-            except Exception:
-                log("Malformed chunk:\n" + traceback.format_exc())
-                continue
-
-        # finalize
-        full_reply = display_history[-1][1]
-        state["raw"].append({"role": "assistant", "content": full_reply})
-        log(f"Reply in {time.time() - start:.2f}s ({len(full_reply)} chars)")
+    # Update raw history for prompt
+    state["raw"].append({"role": "user", "content": user_msg})
+    prompt = build_prompt(state["raw"])
 
+    # Generate the bot reply
+    try:
+        start = time.time()
+        out = generator(prompt)[0]["generated_text"].strip()
+        # Truncate any hallucinated next "User:"
+        if "User:" in out:
+            out = out.split("User:", 1)[0].strip()
+        reply = out
+        log(f"Reply in {time.time()-start:.2f}s ({len(reply)} chars)")
     except Exception:
-        log("Unexpected chat_fn error:\n" + traceback.format_exc())
-        err = "Apologies—an internal error occurred. Please try again."
-        display_history[-1] = (user_msg, err)
-        state["raw"].append({"role": "assistant", "content": err})
-        yield display_history, state
+        log("Inference error:\n" + traceback.format_exc())
+        reply = "Apologies—an internal error occurred. Please try again."
+
+    # Replace the placeholder with the actual reply
+    display_history[-1] = (user_msg, reply)
+    state["raw"].append({"role": "assistant", "content": reply})
+
+    return display_history, state
+
 
 # ---------------------------------------------------------------------------
-# 5. Launch Gradio Blocks UI
+# 5. Launch Gradio Blocks UI with spinner
 # ---------------------------------------------------------------------------
 with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
     gr.Markdown("### SchoolSpirit AI Chat")
 
-    chatbot = gr.Chatbot(
-        value=[{"role":"assistant","content":WELCOME_MSG}],
-        height=480,
-        label="SchoolSpirit AI",
-        type="messages",  # use the new messages format
-    )
-
-    state = gr.State(
-        {"raw": [
-            {"role": "system", "content": SYSTEM_MSG},
-            {"role": "assistant", "content": WELCOME_MSG},
-        ]}
-    )
+    chatbot = gr.Chatbot(value=[("", WELCOME_MSG)], height=480, label="SchoolSpirit AI")
+    state = gr.State({"raw": [
+        {"role": "system", "content": SYSTEM_MSG},
+        {"role": "assistant", "content": WELCOME_MSG},
+    ]})
 
     with gr.Row():
        txt = gr.Textbox(placeholder="Type your question here…", show_label=False, scale=4, lines=1)
        send_btn = gr.Button("Send", variant="primary")
 
-    # Use streaming=True (not stream) per Gradio API
-    send_btn.click(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state], streaming=True)
-    txt.submit(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state], streaming=True)
+    # show_progress=True displays a spinner while waiting
+    send_btn.click(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state], show_progress=True)
+    txt.submit(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state], show_progress=True)
 
-if __name__ == "__main__":
-    try:
-        demo.launch()
-    except Exception:
-        log("UI launch error:\n" + traceback.format_exc())
+demo.launch()
 
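One side effect of replacing the guarded launch with a bare demo.launch(): importing app.py now starts the UI, which is fine on Spaces (app.py is the entrypoint) but blocks any direct smoke test of the new non-streaming callback. A hedged sketch of such a test, assuming the launch guard were restored and using app.py's own names:

from app import chat_fn, SYSTEM_MSG, WELCOME_MSG

history = [("", WELCOME_MSG)]
state = {"raw": [
    {"role": "system", "content": SYSTEM_MSG},
    {"role": "assistant", "content": WELCOME_MSG},
]}
history, state = chat_fn("What grades does SchoolSpirit AI cover?", history, state)
print(history[-1])  # a (user message, reply) tuple, matching gr.Chatbot's pair format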