phanerozoic committed · commit e4f3d46 · verified · 1 parent: 8a16381

Update app.py

Files changed (1): app.py (+91 -82)

app.py CHANGED
@@ -14,7 +14,6 @@ from transformers.utils import logging as hf_logging
os.environ["HF_HOME"] = "/data/.huggingface"
LOG_FILE = "/data/requests.log"

-
def log(msg: str):
    ts = datetime.datetime.utcnow().strftime("%H:%M:%S.%f")[:-3]
    line = f"[{ts}] {msg}"
@@ -22,18 +21,18 @@ def log(msg: str):
    try:
        with open(LOG_FILE, "a") as f:
            f.write(line + "\n")
-    except FileNotFoundError:
        pass

-
# ---------------------------------------------------------------------------
# 1. Configuration constants
# ---------------------------------------------------------------------------
-MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"  # 2 B model fits Spaces
-CONTEXT_TOKENS = 1800  # leave head‑room for reply inside 2k window
MAX_NEW_TOKENS = 64
TEMPERATURE = 0.6
-MAX_INPUT_CH = 300  # UI safeguard

SYSTEM_MSG = (
    "You are **SchoolSpirit AI**, the official digital mascot of "
@@ -50,7 +49,6 @@ WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"

strip = lambda s: re.sub(r"\s+", " ", s.strip())

-
# ---------------------------------------------------------------------------
# 2. Load tokenizer + model (GPU FP‑16 → CPU)
# ---------------------------------------------------------------------------
@@ -67,10 +65,7 @@ try:
    else:
        log("CPU fallback")
        model = AutoModelForCausalLM.from_pretrained(
-            MODEL_ID,
-            device_map="cpu",
-            torch_dtype="auto",
-            low_cpu_mem_usage=True,
        )

    generator = pipeline(
@@ -80,7 +75,8 @@ try:
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=TEMPERATURE,
-        return_full_text=False,  # ← only return the newly generated text
    )
    MODEL_ERR = None
    log("Model loaded ✔")
@@ -89,81 +85,96 @@ except Exception as exc:
    generator = None
    log(MODEL_ERR)

-
# ---------------------------------------------------------------------------
-# 3. Helper: build prompt under token budget
# ---------------------------------------------------------------------------
def build_prompt(raw_history: list[dict]) -> str:
-    """
-    raw_history: list [{'role':'system'|'user'|'assistant', 'content': str}, ...]
-    Keeps trimming oldest user/assistant pair until total tokens < CONTEXT_TOKENS
-    """
-    def render(msg):
-        if msg["role"] == "system":
-            return msg["content"]
-        prefix = "User:" if msg["role"] == "user" else "AI:"
-        return f"{prefix} {msg['content']}"
-
-    # always include system
-    system_msg = [msg for msg in raw_history if msg["role"] == "system"][0]
-    convo = [m for m in raw_history if m["role"] != "system"]
-
-    # iterative trim
-    while True:
-        prompt_parts = [system_msg["content"]] + [render(m) for m in convo] + ["AI:"]
-        token_len = len(tokenizer.encode("\n".join(prompt_parts), add_special_tokens=False))
-        if token_len <= CONTEXT_TOKENS or len(convo) <= 2:
-            break
-        # drop oldest user+assistant pair
-        convo = convo[2:]
-
-    return "\n".join(prompt_parts)


# ---------------------------------------------------------------------------
-# 4. Chat callback
# ---------------------------------------------------------------------------
def chat_fn(user_msg: str, display_history: list, state: dict):
-    """
-    display_history : list[tuple[str,str]] for UI
-    state["raw"]    : list[dict] for prompting
-    """
    user_msg = strip(user_msg or "")
    if not user_msg:
-        return display_history, state

    if len(user_msg) > MAX_INPUT_CH:
        display_history.append((user_msg, f"Input >{MAX_INPUT_CH} chars."))
-        return display_history, state

    if MODEL_ERR:
        display_history.append((user_msg, MODEL_ERR))
-        return display_history, state
-
-    # --- Update raw history
-    state["raw"].append({"role": "user", "content": user_msg})

-    # --- Build prompt within token budget
-    prompt = build_prompt(state["raw"])
-
-    # --- Generate
    try:
-        start = time.time()
-        result = generator(prompt)[0]
-        reply = strip(result["generated_text"])
-        # ── NEW: truncate at any hallucinated next "User:" turn
-        if "User:" in reply:
-            reply = reply.split("User:", 1)[0].strip()
-        log(f"Reply in {time.time() - start:.2f}s ({len(reply)} chars)")
-    except Exception:
-        log("❌ Inference error:\n" + traceback.format_exc())
-        reply = "Apologies—an internal error occurred. Please try again."

-    # --- Append assistant reply to both histories
-    display_history.append((user_msg, reply))
-    state["raw"].append({"role": "assistant", "content": reply})
-    return display_history, state


# ---------------------------------------------------------------------------
# 5. Launch Gradio Blocks UI
@@ -178,24 +189,22 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    )

    state = gr.State(
-        {
-            "raw": [
-                {"role": "system", "content": SYSTEM_MSG},
-                {"role": "assistant", "content": WELCOME_MSG},
-            ]
-        }
    )

    with gr.Row():
-        txt = gr.Textbox(
-            placeholder="Type your question here…",
-            show_label=False,
-            scale=4,
-            lines=1,
-        )
        send_btn = gr.Button("Send", variant="primary")

-    send_btn.click(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])
-    txt.submit(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])

-demo.launch()
os.environ["HF_HOME"] = "/data/.huggingface"
LOG_FILE = "/data/requests.log"

def log(msg: str):
    ts = datetime.datetime.utcnow().strftime("%H:%M:%S.%f")[:-3]
    line = f"[{ts}] {msg}"

    try:
        with open(LOG_FILE, "a") as f:
            f.write(line + "\n")
+    except Exception:
+        # If logging to file fails, we still want the process to continue
        pass

# ---------------------------------------------------------------------------
# 1. Configuration constants
# ---------------------------------------------------------------------------
+MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
+CONTEXT_TOKENS = 1800
MAX_NEW_TOKENS = 64
TEMPERATURE = 0.6
+MAX_INPUT_CH = 300

SYSTEM_MSG = (
    "You are **SchoolSpirit AI**, the official digital mascot of "

strip = lambda s: re.sub(r"\s+", " ", s.strip())

# ---------------------------------------------------------------------------
# 2. Load tokenizer + model (GPU FP‑16 → CPU)
# ---------------------------------------------------------------------------
 
    else:
        log("CPU fallback")
        model = AutoModelForCausalLM.from_pretrained(
+            MODEL_ID, device_map="cpu", torch_dtype="auto", low_cpu_mem_usage=True
        )

    generator = pipeline(

        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=TEMPERATURE,
+        return_full_text=False,
+        streaming=True,  # ← enable token-by-token streaming
    )
    MODEL_ERR = None
    log("Model loaded ✔")

    generator = None
    log(MODEL_ERR)
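Note on the new streaming=True argument: the transformers text-generation pipeline does not appear to expose a streaming flag; token-by-token output is usually produced with TextIteratorStreamer driving model.generate() in a background thread. A minimal sketch of that pattern, reusing MODEL_ID from this file (the stream_reply helper name is illustrative, not part of the commit):

# Sketch: token streaming with TextIteratorStreamer (assumption about intent;
# not part of this commit). Reuses the same MODEL_ID as app.py.
import threading

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", torch_dtype="auto")


def stream_reply(prompt: str, max_new_tokens: int = 64):
    """Yield text chunks as they are generated (illustrative helper)."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens,
                      do_sample=True, temperature=0.6)
    # generate() blocks, so run it in a worker thread and read from the streamer here
    threading.Thread(target=model.generate, kwargs=gen_kwargs).start()
    for text in streamer:
        yield text

Iterating stream_reply(prompt) yields partial strings that could be appended to display_history[-1], which is the shape the new chat_fn expects from generator(prompt).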
 
 
# ---------------------------------------------------------------------------
+# 3. Helper: build prompt under token budget (with fallback on error)
# ---------------------------------------------------------------------------
def build_prompt(raw_history: list[dict]) -> str:
+    try:
+        def render(msg):
+            if msg["role"] == "system":
+                return msg["content"]
+            prefix = "User:" if msg["role"] == "user" else "AI:"
+            return f"{prefix} {msg['content']}"
+
+        system_msg = next(m for m in raw_history if m["role"] == "system")
+        convo = [m for m in raw_history if m["role"] != "system"]
+
+        while True:
+            parts = [system_msg["content"]] + [render(m) for m in convo] + ["AI:"]
+            token_len = len(tokenizer.encode("\n".join(parts), add_special_tokens=False))
+            if token_len <= CONTEXT_TOKENS or len(convo) <= 2:
+                break
+            convo = convo[2:]
+        return "\n".join(parts)

+    except Exception:
+        log("Error building prompt:\n" + traceback.format_exc())
+        # Fallback: include system + last two messages
+        sys_text = next((m["content"] for m in raw_history if m["role"]=="system"), "")
+        tail = [m for m in raw_history if m["role"]!="system"][-2:]
+        fallback = sys_text + "\n" + "\n".join(
+            f"{'User:' if m['role']=='user' else 'AI:'} {m['content']}" for m in tail
+        ) + "\nAI:"
+        return fallback
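For an instruct model such as Granite, the tokenizer's chat template is the more common way to turn the {"role", "content"} history into a prompt than hand-rolled "User:"/"AI:" prefixes. A small sketch under that assumption (it reuses the tokenizer loaded above and is not part of this commit):

# Sketch: prompt construction via the tokenizer's chat template (assumption,
# not what this commit does). Reuses `tokenizer` and the raw-history format.
def build_prompt_with_template(raw_history: list[dict]) -> str:
    # apply_chat_template accepts the same {"role": ..., "content": ...} dicts
    return tokenizer.apply_chat_template(
        raw_history,
        tokenize=False,              # return a string rather than token IDs
        add_generation_prompt=True,  # append the assistant-turn marker
    )

The trimming loop above would still apply: encode the rendered string and drop the oldest user/assistant pair while it exceeds CONTEXT_TOKENS.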
 
# ---------------------------------------------------------------------------
+# 4. Chat callback (streaming generator + robust error handling)
# ---------------------------------------------------------------------------
def chat_fn(user_msg: str, display_history: list, state: dict):
    user_msg = strip(user_msg or "")
+    # Yield nothing if empty input
    if not user_msg:
+        yield display_history, state
+        return

+    # Input length check
    if len(user_msg) > MAX_INPUT_CH:
        display_history.append((user_msg, f"Input >{MAX_INPUT_CH} chars."))
+        yield display_history, state
+        return

+    # Model load error
    if MODEL_ERR:
        display_history.append((user_msg, MODEL_ERR))
+        yield display_history, state
+        return

    try:
+        # Add user to history and display
+        state["raw"].append({"role": "user", "content": user_msg})
+        display_history.append((user_msg, ""))

+        prompt = build_prompt(state["raw"])
+        start = time.time()
+        partial = ""
+
+        # Stream tokens as they arrive
+        for chunk in generator(prompt):
+            try:
+                new_text = strip(chunk.get("generated_text", ""))
+                # Truncate any hallucinated next-turn
+                if "User:" in new_text:
+                    new_text = new_text.split("User:", 1)[0].strip()
+                partial += new_text
+                display_history[-1] = (user_msg, partial)
+                yield display_history, state
+            except Exception:
+                # Skip malformed chunk but keep streaming
+                log("Malformed chunk:\n" + traceback.format_exc())
+                continue
+
+        # Finalize
+        full_reply = display_history[-1][1]
+        state["raw"].append({"role": "assistant", "content": full_reply})
+        log(f"Reply in {time.time()-start:.2f}s ({len(full_reply)} chars)")

+    except Exception:
+        # Catch-all for any unexpected errors in chat flow
+        log("Unexpected chat_fn error:\n" + traceback.format_exc())
+        err = "Apologies—an internal error occurred. Please try again."
+        display_history[-1] = (user_msg, err)
+        state["raw"].append({"role": "assistant", "content": err})
+        yield display_history, state
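Since chat_fn is now a generator, it can be exercised outside Gradio by draining it. A hypothetical smoke test (the history and state shapes match the ones app.py builds; the question text is invented):

# Hypothetical smoke test for the streaming callback; not part of the commit.
test_state = {"raw": [
    {"role": "system", "content": SYSTEM_MSG},
    {"role": "assistant", "content": WELCOME_MSG},
]}
history = []
for history, test_state in chat_fn("What does SchoolSpirit AI do?", history, test_state):
    pass  # each yield carries a progressively longer partial reply
print(history[-1][1])  # final reply text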
 
# ---------------------------------------------------------------------------
# 5. Launch Gradio Blocks UI

    )

    state = gr.State(
+        {"raw": [
+            {"role": "system", "content": SYSTEM_MSG},
+            {"role": "assistant", "content": WELCOME_MSG},
+        ]}
    )

    with gr.Row():
+        txt = gr.Textbox(placeholder="Type your question here…", show_label=False, scale=4, lines=1)
        send_btn = gr.Button("Send", variant="primary")

+    # Enable streaming updates in the UI
+    send_btn.click(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state], stream=True)
+    txt.submit(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state], stream=True)

+if __name__ == "__main__":
+    try:
+        demo.launch()
+    except Exception as exc:
+        log(f"UI launch error:\n{traceback.format_exc()}")