# SchoolSpiritAI / app.py
import os
import re
import time
import datetime
import traceback
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers.utils import logging as hf_logging
# ---------------------------------------------------------------------------
# 0. Paths & basic logging helper
# ---------------------------------------------------------------------------
os.environ["HF_HOME"] = "/data/.huggingface"
LOG_FILE = "/data/requests.log"
def log(msg: str):
    """Print a timestamped message and append it to the request log."""
    ts = datetime.datetime.utcnow().strftime("%H:%M:%S.%f")[:-3]
    line = f"[{ts}] {msg}"
    print(line, flush=True)
    try:
        with open(LOG_FILE, "a") as f:
            f.write(line + "\n")
    except FileNotFoundError:
        # /data may be absent (e.g. no persistent storage); skip file logging
        pass
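# A line written by log() looks like this (the timestamp is illustrative):
#   [14:03:27.123] Loading tokenizer …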
# ---------------------------------------------------------------------------
# 1. Configuration constants
# ---------------------------------------------------------------------------
MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
CONTEXT_TOKENS = 1800
MAX_NEW_TOKENS = 64
TEMPERATURE = 0.6
MAX_INPUT_CH = 300
SYSTEM_MSG = (
    "You are **SchoolSpirit AI**, the official digital mascot of "
    "SchoolSpirit AI LLC. Founded by Charles Norton in 2025, the company "
    "deploys on‑prem AI chat mascots, fine‑tunes language models, and ships "
    "turnkey GPU servers to K‑12 schools.\n\n"
    "RULES:\n"
    "• Friendly, concise (≤4 sentences unless prompted).\n"
    "• No personal data collection; no medical/legal/financial advice.\n"
    "• If uncertain, admit it & suggest human follow‑up.\n"
    "• Avoid profanity, politics, mature themes."
)
WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"
# Collapse runs of whitespace in user input into single spaces
strip = lambda s: re.sub(r"\s+", " ", s.strip())
# ---------------------------------------------------------------------------
# 2. Load tokenizer + model (GPU FP‑16, CPU fallback)
# ---------------------------------------------------------------------------
hf_logging.set_verbosity_error()
try:
    log("Loading tokenizer …")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    if torch.cuda.is_available():
        log("GPU detected → FP‑16")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, device_map="auto", torch_dtype=torch.float16
        )
    else:
        log("CPU fallback")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            device_map="cpu",
            torch_dtype="auto",
            low_cpu_mem_usage=True,
        )

    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=TEMPERATURE,
        return_full_text=False,
    )
    MODEL_ERR = None
    log("Model loaded ✔")
except Exception as exc:
    MODEL_ERR = f"Model load error: {exc}"
    generator = None
    log(MODEL_ERR)
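# Optional smoke test (disabled): chat_fn below calls the pipeline as
# generator(prompt)[0]["generated_text"], so a minimal local check could look
# like this; the "User: Hello\nAI:" prompt is just an illustrative string.
# if generator is not None:
#     log("Smoke test: " + generator("User: Hello\nAI:")[0]["generated_text"])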
# ---------------------------------------------------------------------------
# 3. Helper: build prompt under token budget
# ---------------------------------------------------------------------------
def build_prompt(raw_history: list[dict]) -> str:
    """Assemble the plain-text prompt, trimming old turns to fit the token budget."""
    system_msg = [m for m in raw_history if m["role"] == "system"][0]
    convo = [m for m in raw_history if m["role"] != "system"]

    def render(msg):
        prefix = "User:" if msg["role"] == "user" else "AI:"
        return f"{prefix} {msg['content']}" if msg["role"] != "system" else msg["content"]

    # Drop the oldest user/assistant pair until the prompt fits CONTEXT_TOKENS
    while True:
        parts = [system_msg["content"]] + [render(m) for m in convo] + ["AI:"]
        token_len = len(tokenizer.encode("\n".join(parts), add_special_tokens=False))
        if token_len <= CONTEXT_TOKENS or len(convo) <= 2:
            break
        convo = convo[2:]

    return "\n".join([system_msg["content"]] + [render(m) for m in convo] + ["AI:"])
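# Illustrative prompt produced by build_prompt for a short history (the user
# question is a made-up example, not part of the app):
#
#   <SYSTEM_MSG text>
#   AI: Welcome to SchoolSpirit AI! Do you have any questions?
#   User: What does SchoolSpirit AI offer schools?
#   AI: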
# ---------------------------------------------------------------------------
# 4. Chat callback with immediate user echo & spinner
# ---------------------------------------------------------------------------
def chat_fn(user_msg: str, display_history: list, state: dict):
    user_msg = strip(user_msg or "")
    if not user_msg:
        return display_history, state
    if len(user_msg) > MAX_INPUT_CH:
        display_history.append((user_msg, f"Input >{MAX_INPUT_CH} chars."))
        return display_history, state
    if MODEL_ERR:
        display_history.append((user_msg, MODEL_ERR))
        return display_history, state

    # Immediately append the user message with a placeholder
    display_history.append((user_msg, ""))

    # Update raw history for prompt
    state["raw"].append({"role": "user", "content": user_msg})
    prompt = build_prompt(state["raw"])

    # Generate the bot reply
    try:
        start = time.time()
        out = generator(prompt)[0]["generated_text"].strip()
        # Truncate any hallucinated next "User:"
        if "User:" in out:
            out = out.split("User:", 1)[0].strip()
        reply = out
        log(f"Reply in {time.time()-start:.2f}s ({len(reply)} chars)")
    except Exception:
        log("❌ Inference error:\n" + traceback.format_exc())
        reply = "Apologies—an internal error occurred. Please try again."

    # Replace the placeholder with the actual reply
    display_history[-1] = (user_msg, reply)
    state["raw"].append({"role": "assistant", "content": reply})
    return display_history, state
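# Illustrative shapes of the two parallel histories after one exchange
# ("Hi!" is a made-up user message):
#   display_history -> [("", WELCOME_MSG), ("Hi!", "<model reply>")]
#   state["raw"]    -> [{"role": "system",    "content": SYSTEM_MSG},
#                       {"role": "assistant", "content": WELCOME_MSG},
#                       {"role": "user",      "content": "Hi!"},
#                       {"role": "assistant", "content": "<model reply>"}]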
# ---------------------------------------------------------------------------
# 5. Launch Gradio Blocks UI with spinner
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown("### SchoolSpirit AI Chat")
    chatbot = gr.Chatbot(value=[("", WELCOME_MSG)], height=480, label="SchoolSpirit AI")
    state = gr.State({"raw": [
        {"role": "system", "content": SYSTEM_MSG},
        {"role": "assistant", "content": WELCOME_MSG},
    ]})

    with gr.Row():
        txt = gr.Textbox(placeholder="Type your question here…", show_label=False, scale=4, lines=1)
        send_btn = gr.Button("Send", variant="primary")

    # show_progress=True displays a spinner while waiting
    send_btn.click(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state], show_progress=True)
    txt.submit(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state], show_progress=True)

demo.launch()