# SchoolSpiritAI / app.py
# Author: phanerozoic — commit e8212fa ("Update app.py"), verified
import os, re, time, datetime, traceback, torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers.utils import logging as hf_logging
# ---------- Logging ---------------------------------------------------------
os.environ["HF_HOME"] = "/data/.huggingface"  # cache HF downloads on the persistent volume
LOG_FILE = "/data/requests.log"


def log(msg: str) -> None:
    """Print *msg* with a millisecond UTC timestamp and append it to LOG_FILE.

    Logging must never take the app down: filesystem failures (missing /data
    volume, permissions, read-only disk) are swallowed after the console print.
    """
    # datetime.utcnow() is deprecated since 3.12 — use an aware UTC datetime.
    now = datetime.datetime.now(datetime.timezone.utc)
    ts = now.strftime("%H:%M:%S.%f")[:-3]  # trim microseconds to milliseconds
    line = f"[{ts}] {msg}"
    print(line, flush=True)
    try:
        with open(LOG_FILE, "a") as f:
            f.write(line + "\n")
    except OSError:  # broader than FileNotFoundError: perms, disk full, RO fs…
        pass
# ---------- Config ----------------------------------------------------------
MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
# history window (turns) / generation budget (tokens) / input length cap (chars)
MAX_TURNS, MAX_TOKENS, MAX_INPUT_CH = 6, 128, 400
SYSTEM_MSG = (
    "You are **SchoolSpirit AI**, the digital mascot for SchoolSpirit AI LLC, "
    "founded by Charles Norton in 2025. The company installs on‑prem AI chat "
    "mascots, offers custom fine‑tuning, and ships turnkey GPU hardware to schools.\n\n"
    "Guidelines:\n"
    "• Warm, concise answers (max 4 sentences).\n"
    "• No personal‑data collection or sensitive advice.\n"
    "• If unsure, say so and suggest a human follow‑up.\n"
    "• Avoid profanity, politics, or mature themes."
)
WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"


def strip(s: str) -> str:
    """Collapse internal whitespace runs to single spaces and trim the ends."""
    # A def instead of a lambda assignment (PEP 8 E731); behavior unchanged.
    return re.sub(r"\s+", " ", s.strip())
# ---------- Load model (GPU FP‑16 → CPU fallback) ---------------------------
hf_logging.set_verbosity_error()
try:
    log("Loading tokenizer …")
    tok = AutoTokenizer.from_pretrained(MODEL_ID)

    # Pick loading options once, then issue a single from_pretrained call.
    if torch.cuda.is_available():
        log("GPU detected → FP‑16")
        load_opts = {"device_map": "auto", "torch_dtype": torch.float16}
    else:
        log("CPU fallback")
        load_opts = {
            "device_map": "cpu",
            "torch_dtype": "auto",
            "low_cpu_mem_usage": True,
        }
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_opts)

    gen = pipeline(
        "text-generation",
        model=model,
        tokenizer=tok,
        max_new_tokens=MAX_TOKENS,
        do_sample=True,
        temperature=0.6,
        pad_token_id=tok.eos_token_id,
    )
    MODEL_ERR = None
    log("Model loaded ✔")
except Exception as exc:  # noqa: BLE001
    # Keep the app alive; chat_fn reports MODEL_ERR instead of answering.
    MODEL_ERR, gen = f"Model load error: {exc}", None
    log(MODEL_ERR)
# ---------- Chat callback ---------------------------------------------------
def chat_fn(user_msg: str, history: list[dict]):
    """Gradio chat callback.

    Parameters
    ----------
    user_msg : raw text from the textbox.
    history  : list of ``{'role': 'user'|'assistant', 'content': str}`` dicts
               (Gradio "messages" format).

    Returns a NEW list — the incoming history plus the user turn (when
    accepted) and one assistant turn; the caller's list is never mutated.
    """
    if MODEL_ERR:  # model never loaded — surface the error instead of crashing
        return history + [{"role": "assistant", "content": MODEL_ERR}]

    user_msg = strip(user_msg or "")
    if not user_msg:
        return history + [{"role": "assistant", "content": "Please type something."}]
    if len(user_msg) > MAX_INPUT_CH:
        warn = f"Message too long (>{MAX_INPUT_CH} chars)."
        return history + [{"role": "assistant", "content": warn}]

    # Append the user turn to a copy (no in-place mutation of the argument).
    history = history + [{"role": "user", "content": user_msg}]

    # System prompt + at most the last MAX_TURNS user/assistant exchanges.
    convo = [m for m in history if m["role"] != "system"][-MAX_TURNS * 2 :]
    prompt = "\n".join(
        [SYSTEM_MSG]
        + [f"{'User' if m['role']=='user' else 'AI'}: {m['content']}" for m in convo]
        + ["AI:"]
    )

    try:
        raw = gen(prompt)[0]["generated_text"]
        # Pipeline echoes the prompt; keep only text after the first "AI:".
        reply = strip(raw.split("AI:", 1)[-1])
        # Cut off any hallucinated next turn. maxsplit is keyword-only
        # for re.split from Python 3.13; the positional form is deprecated.
        reply = re.split(r"\b(?:User:|AI:)", reply, maxsplit=1)[0].strip()
    except Exception:  # noqa: BLE001 — any backend failure becomes a polite reply
        log("❌ Inference error:\n" + traceback.format_exc())
        reply = "Sorry—backend crashed. Please try again later."

    return history + [{"role": "assistant", "content": reply}]
# ---------- Launch ----------------------------------------------------------
# Pre-build the chat widget so it seeds exactly ONE welcome bubble.
_chatbot = gr.Chatbot(
    height=480,
    type="messages",
    value=[{"role": "assistant", "content": WELCOME_MSG}],
)

gr.ChatInterface(
    fn=chat_fn,
    chatbot=_chatbot,
    additional_inputs=None,
    title="SchoolSpirit AI Chat",
    theme=gr.themes.Soft(primary_hue="blue"),
    examples=None,
).launch()