# SchoolSpiritAI / app.py
import os, re, time, datetime, traceback, torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers.utils import logging as hf_logging
# -------------------------------------------------------------------
# 1. Logging helpers
# -------------------------------------------------------------------
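# On Hugging Face Spaces, /data is the persistent-storage mount (when enabled),
# so the model cache and the request log survive Space restarts.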
os.environ["HF_HOME"] = "/data/.huggingface"
LOG_FILE = "/data/requests.log"
def log(msg: str):
    # timezone-aware replacement for the deprecated datetime.utcnow()
    ts = datetime.datetime.now(datetime.timezone.utc).strftime("%H:%M:%S.%f")[:-3]
    line = f"[{ts}] {msg}"
    print(line, flush=True)
    try:
        with open(LOG_FILE, "a") as f:
            f.write(line + "\n")
    except OSError:
        # /data may be missing or read-only outside the Space; stdout still logs.
        pass
# -------------------------------------------------------------------
# 2. Configuration
# -------------------------------------------------------------------
MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
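# History window (turns kept), reply length (new tokens), and input-size cap.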
MAX_TURNS, MAX_TOKENS, MAX_INPUT_CH = 4, 64, 300
SYSTEM_MSG = (
    "You are **SchoolSpirit AI**, the digital mascot for SchoolSpirit AI LLC, "
    "founded by Charles Norton in 2025. The company installs on‑prem AI chat "
    "mascots, offers custom fine‑tuning, and ships turnkey GPU hardware to "
    "K‑12 schools.\n\n"
    "GUIDELINES:\n"
    "• Warm, encouraging tone for students, parents, staff.\n"
    "• Replies ≤ 4 sentences unless asked for detail.\n"
    "• If unsure/out‑of‑scope: say so and suggest human follow‑up.\n"
    "• No personal‑data collection or sensitive advice.\n"
    "• No profanity, politics, or mature themes."
)
WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"
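# Collapse runs of whitespace so replies render as a single clean line.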
def strip(s: str) -> str:
    return re.sub(r"\s+", " ", s.strip())
# -------------------------------------------------------------------
# 3. Load model (GPU FP‑16 → CPU fallback)
# -------------------------------------------------------------------
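# Surface only errors from transformers to keep the request log readable.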
hf_logging.set_verbosity_error()
try:
    log("Loading tokenizer …")
    tok = AutoTokenizer.from_pretrained(MODEL_ID)
    if torch.cuda.is_available():
        log("GPU detected → FP‑16")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, device_map="auto", torch_dtype=torch.float16
        )
    else:
        log("CPU fallback")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, device_map="cpu", torch_dtype="auto", low_cpu_mem_usage=True
        )
    gen = pipeline(
        "text-generation",
        model=model,
        tokenizer=tok,
        max_new_tokens=MAX_TOKENS,
        do_sample=True,
        temperature=0.6,
        return_full_text=False,  # yield only the completion, not the echoed prompt
    )
    MODEL_ERR = None
    log("Model loaded ✔")
except Exception as exc:  # noqa: BLE001
    MODEL_ERR, gen = f"Model load error: {exc}", None
    log(MODEL_ERR)
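# Optional local smoke test (hypothetical prompt; left commented out so it
# never runs inside the Space):
#   if gen is not None:
#       print(gen("User: Hi there!\nAI:")[0]["generated_text"])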
# -------------------------------------------------------------------
# 4. Chat callback
# -------------------------------------------------------------------
def chat_fn(user_msg: str, history: list[tuple[str, str]], state: dict):
    """
    history: list of (user, assistant) tuples (Gradio's default format)
    state  : dict carrying the system prompt + raw history for the model
    Returns the updated history (for the UI) and state (for the next round).
    """
    if MODEL_ERR:
        return history + [(user_msg, MODEL_ERR)], state
    user_msg = strip(user_msg or "")
    if not user_msg:
        return history + [(user_msg, "Please type something.")], state
    if len(user_msg) > MAX_INPUT_CH:
        warn = f"Message too long (>{MAX_INPUT_CH} chars)."
        return history + [(user_msg, warn)], state
    # ------------------------------------------------ Prompt assembly
    raw_hist = state.get("raw", [])
    raw_hist.append({"role": "user", "content": user_msg})
    # keep the system message plus the last MAX_TURNS user/assistant exchanges
    convo = [m for m in raw_hist if m["role"] != "system"][-MAX_TURNS * 2 :]
    raw_hist = [{"role": "system", "content": SYSTEM_MSG}] + convo
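    # The assembled prompt looks like this (illustrative example):
    #   <SYSTEM_MSG>
    #   User: Do you offer fine-tuning?
    #   AI: Yes, we tailor the mascot to each school ...
    #   User: <latest message>
    #   AI: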
    prompt = "\n".join(
        [
            m["content"]
            if m["role"] == "system"
            else f'{"User" if m["role"] == "user" else "AI"}: {m["content"]}'
            for m in raw_hist
        ]
        + ["AI:"]
    )
    try:
        raw = gen(prompt)[0]["generated_text"]
        # With return_full_text=False, raw holds only the new completion; cut it
        # off at any hallucinated follow-on turn, then collapse whitespace.
        reply = strip(re.split(r"\b(?:User:|AI:)", raw, maxsplit=1)[0])
    except Exception:  # noqa: BLE001
        log("❌ Inference error:\n" + traceback.format_exc())
        reply = "Sorry, the backend crashed. Please try again later."
    # ------------------------------------------------ Update state + UI history
    raw_hist.append({"role": "assistant", "content": reply})
    state["raw"] = raw_hist
    history.append((user_msg, reply))
    return history, state
# -------------------------------------------------------------------
# 5. Launch
# -------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    chatbot = gr.Chatbot(
        value=[(None, WELCOME_MSG)],  # None hides the empty user bubble
        height=480,
        label="SchoolSpirit AI",
    )
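    # gr.State is per-session, so each visitor gets an isolated history dict.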
    state = gr.State({"raw": [{"role": "system", "content": SYSTEM_MSG}]})
    with gr.Row():
        txt = gr.Textbox(
            scale=4, placeholder="Type your question here...", show_label=False
        )
        send = gr.Button("Send", variant="primary")
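    # Button click and pressing Enter both route through the same callback.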
    send.click(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])
    txt.submit(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])

demo.launch()