# SchoolSpiritAI / app.py
import os

# NOTE: HF_HOME must be set before transformers/huggingface_hub are imported,
# because the cache location is resolved at import time.
os.environ["HF_HOME"] = "/data/.huggingface"

import re
import time
import datetime
import traceback
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers.utils import logging as hf_logging

# ---------------------------------------------------------------------------
# 0. Paths & basic logging helper
# ---------------------------------------------------------------------------
LOG_FILE = "/data/requests.log"
def log(msg: str):
    ts = datetime.datetime.now(datetime.timezone.utc).strftime("%H:%M:%S.%f")[:-3]
    line = f"[{ts}] {msg}"
    print(line, flush=True)
    try:
        with open(LOG_FILE, "a") as f:
            f.write(line + "\n")
    except Exception:
        pass  # tolerate logging failures
# ---------------------------------------------------------------------------
# 1. Configuration constants
# ---------------------------------------------------------------------------
MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
CONTEXT_TOKENS = 1800
MAX_NEW_TOKENS = 64
TEMPERATURE = 0.6
MAX_INPUT_CH = 300
SYSTEM_MSG = (
    "You are **SchoolSpirit AI**, the official digital mascot of "
    "SchoolSpirit AI LLC. Founded by Charles Norton in 2025, the company "
    "deploys on‑prem AI chat mascots, fine‑tunes language models, and ships "
    "turnkey GPU servers to K‑12 schools.\n\n"
    "RULES:\n"
    "• Friendly, concise (≤4 sentences unless prompted).\n"
    "• No personal data collection; no medical/legal/financial advice.\n"
    "• If uncertain, admit it & suggest human follow‑up.\n"
    "• Avoid profanity, politics, mature themes."
)
WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"
def strip(s: str) -> str:
    """Collapse runs of whitespace to single spaces."""
    return re.sub(r"\s+", " ", s.strip())
# ---------------------------------------------------------------------------
# 2. Load tokenizer + model (GPU FP‑16 → CPU)
# ---------------------------------------------------------------------------
hf_logging.set_verbosity_error()
try:
    log("Loading tokenizer …")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    if torch.cuda.is_available():
        log("GPU detected → FP‑16")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, device_map="auto", torch_dtype=torch.float16
        )
    else:
        log("CPU fallback")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, device_map="cpu", torch_dtype="auto", low_cpu_mem_usage=True
        )
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=TEMPERATURE,
        return_full_text=False,
        # NOTE: transformers' text-generation pipeline has no `streaming`
        # kwarg; it returns the completed text in one chunk. See the
        # TextIteratorStreamer sketch below for token-by-token streaming.
    )
    MODEL_ERR = None
    log("Model loaded ✔")
except Exception as exc:
    MODEL_ERR = f"Model load error: {exc}"
    generator = None
    log(MODEL_ERR)
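
# ---------------------------------------------------------------------------
# Optional: true token-by-token streaming (illustrative sketch, not wired in).
# The pipeline above returns each reply in a single chunk. For incremental
# output, transformers' TextIteratorStreamer runs generate() in a background
# thread and yields decoded text as it arrives. `stream_reply` is a
# hypothetical helper; it assumes the model and tokenizer loaded successfully.
# ---------------------------------------------------------------------------
from threading import Thread
from transformers import TextIteratorStreamer

def stream_reply(prompt: str):
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    Thread(
        target=model.generate,
        kwargs=dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
        ),
    ).start()
    # Yield plain-text pieces as soon as they are decoded.
    yield from streamer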
# ---------------------------------------------------------------------------
# 3. Helper: build prompt under token budget (with fallback on error)
# ---------------------------------------------------------------------------
def build_prompt(raw_history: list[dict]) -> str:
    try:
        def render(msg):
            if msg["role"] == "system":
                return msg["content"]
            prefix = "User:" if msg["role"] == "user" else "AI:"
            return f"{prefix} {msg['content']}"

        system_msg = next(m for m in raw_history if m["role"] == "system")
        convo = [m for m in raw_history if m["role"] != "system"]
        # Drop the oldest user/AI pair until the prompt fits the token budget.
        while True:
            parts = [system_msg["content"]] + [render(m) for m in convo] + ["AI:"]
            token_len = len(tokenizer.encode("\n".join(parts), add_special_tokens=False))
            if token_len <= CONTEXT_TOKENS or len(convo) <= 2:
                break
            convo = convo[2:]
        return "\n".join(parts)
    except Exception:
        log("Error building prompt:\n" + traceback.format_exc())
        # Fallback: include the system message plus the last two turns.
        sys_text = next((m["content"] for m in raw_history if m["role"] == "system"), "")
        tail = [m for m in raw_history if m["role"] != "system"][-2:]
        return sys_text + "\n" + "\n".join(
            f"{'User:' if m['role'] == 'user' else 'AI:'} {m['content']}" for m in tail
        ) + "\nAI:"
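
# For reference, build_prompt renders the history as a plain transcript
# (illustrative shape only; the actual content depends on the conversation):
#
#   <SYSTEM_MSG>
#   User: How does the mascot work?
#   AI: It runs on a school-hosted server …
#   User: <latest message>
#   AI: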
# ---------------------------------------------------------------------------
# 4. Chat callback (generator that yields partial updates + robust errors)
# ---------------------------------------------------------------------------
def chat_fn(user_msg: str, display_history: list, state: dict):
    user_msg = strip(user_msg or "")
    if not user_msg:
        yield display_history, state
        return
    if len(user_msg) > MAX_INPUT_CH:
        display_history.append({"role": "user", "content": user_msg})
        display_history.append(
            {"role": "assistant", "content": f"Input >{MAX_INPUT_CH} chars."}
        )
        yield display_history, state
        return
    if MODEL_ERR:
        display_history.append({"role": "user", "content": user_msg})
        display_history.append({"role": "assistant", "content": MODEL_ERR})
        yield display_history, state
        return
    try:
        # Record the user turn plus an empty assistant placeholder to fill in.
        state["raw"].append({"role": "user", "content": user_msg})
        display_history.append({"role": "user", "content": user_msg})
        display_history.append({"role": "assistant", "content": ""})
        prompt = build_prompt(state["raw"])
        start = time.time()
        partial = ""
        # The pipeline returns a list with one result dict per prompt.
        for chunk in generator(prompt):
            try:
                new_text = strip(chunk.get("generated_text", ""))
                # Trim anything the model generates past its own turn.
                if "User:" in new_text:
                    new_text = new_text.split("User:", 1)[0].strip()
                partial += new_text
                display_history[-1]["content"] = partial
                yield display_history, state
            except Exception:
                log("Malformed chunk:\n" + traceback.format_exc())
                continue
        # Finalize: persist the completed reply into the raw history.
        full_reply = display_history[-1]["content"]
        state["raw"].append({"role": "assistant", "content": full_reply})
        log(f"Reply in {time.time() - start:.2f}s ({len(full_reply)} chars)")
    except Exception:
        log("Unexpected chat_fn error:\n" + traceback.format_exc())
        err = "Apologies—an internal error occurred. Please try again."
        display_history[-1] = {"role": "assistant", "content": err}
        state["raw"].append({"role": "assistant", "content": err})
        yield display_history, state
# ---------------------------------------------------------------------------
# 5. Launch Gradio Blocks UI
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown("### SchoolSpirit AI Chat")
    chatbot = gr.Chatbot(
        value=[{"role": "assistant", "content": WELCOME_MSG}],
        height=480,
        label="SchoolSpirit AI",
        type="messages",  # role/content dicts, matching what chat_fn appends
    )
    state = gr.State(
        {
            "raw": [
                {"role": "system", "content": SYSTEM_MSG},
                {"role": "assistant", "content": WELCOME_MSG},
            ]
        }
    )
    with gr.Row():
        txt = gr.Textbox(
            placeholder="Type your question here…", show_label=False, scale=4, lines=1
        )
        send_btn = gr.Button("Send", variant="primary")
    # Gradio event listeners take no `streaming` kwarg; generator callbacks
    # stream their yields to the UI automatically.
    send_btn.click(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])
    txt.submit(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])
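
# Generator callbacks stream through Gradio's queue. Gradio 4.x enables the
# queue by default; on older versions it may need to be switched on explicitly.
demo.queue()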
if __name__ == "__main__":
    try:
        demo.launch()
    except Exception:
        log("UI launch error:\n" + traceback.format_exc())