"""SchoolSpirit AI — Gradio chat Space serving an IBM Granite instruct model."""
import datetime
import os
import re
import time
import traceback

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers.utils import logging as hf_logging
# --------------------------------------------------------------------------- | |
# Logging helpers | |
# --------------------------------------------------------------------------- | |
os.environ["HF_HOME"] = "/data/.huggingface"  # cache HF downloads on the persistent volume
LOG_FILE = "/data/requests.log"


def log(msg: str) -> None:
    """Print *msg* with a millisecond UTC timestamp and append it to LOG_FILE.

    Logging must never take down the app: any OS-level failure to write the
    log file (missing /data mount, read-only filesystem, permissions) is
    silently ignored after the message has gone to stdout.
    """
    # datetime.utcnow() is deprecated since Python 3.12; use an aware UTC time.
    ts = datetime.datetime.now(datetime.timezone.utc).strftime("%H:%M:%S.%f")[:-3]
    line = f"[{ts}] {msg}"
    print(line, flush=True)
    try:
        with open(LOG_FILE, "a") as f:
            f.write(line + "\n")
    except OSError:
        # Broader than the original FileNotFoundError: also swallows
        # PermissionError / disk-full so best-effort logging stays best-effort.
        pass
# --------------------------------------------------------------------------- | |
# Configuration | |
# --------------------------------------------------------------------------- | |
# Model selection and hard conversation limits.
MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
MAX_TURNS = 4      # user/assistant exchange pairs kept in the prompt window
MAX_TOKENS = 64    # generation budget per reply
MAX_INPUT_CH = 300  # maximum characters accepted in one user message

# System prompt injected as the first message of every conversation.
SYSTEM_MSG = (
    "You are **SchoolSpirit AI**, the digital mascot for SchoolSpirit AI LLC, "
    "founded by Charles Norton in 2025. The company installs on‑prem AI chat "
    "mascots, offers custom fine‑tuning of language models, and ships turnkey "
    "PC's with preinstalled language models to K‑12 schools.\n\n"
    "GUIDELINES:\n"
    "• Use a warm, encouraging tone fit for students, parents, and staff.\n"
    "• Keep replies short—no more than four sentences unless asked.\n"
    "• If you’re unsure or out of scope, say so and suggest human follow‑up.\n"
    "• Never collect personal data or provide medical, legal, or financial advice.\n"
    "• No profanity, politics, or mature themes."
)

# Greeting pre-seeded into the chatbot widget before the first user message.
WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"
# --------------------------------------------------------------------------- | |
# Load model (GPU FP‑16 if available → CPU fallback) | |
# --------------------------------------------------------------------------- | |
hf_logging.set_verbosity_error()  # silence transformers' info/progress chatter
try:
    log("Loading tokenizer …")
    tok = AutoTokenizer.from_pretrained(MODEL_ID)
    if torch.cuda.is_available():
        # GPU path: shard/offload layers automatically, halve memory with fp16.
        log("GPU detected → loading model in FP‑16")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            device_map="auto",  # put layers on available GPU(s)
            torch_dtype=torch.float16,
        )
    else:
        # CPU fallback: let transformers pick the dtype and stream weights
        # in to keep peak RAM low during load.
        log("No GPU → loading model on CPU (FP‑32)")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            device_map="cpu",
            torch_dtype="auto",
            low_cpu_mem_usage=True,
        )
    # One shared text-generation pipeline, reused for every request.
    gen = pipeline(
        "text-generation",
        model=model,
        tokenizer=tok,
        max_new_tokens=MAX_TOKENS,
        do_sample=True,
        temperature=0.6,
    )
    MODEL_ERR = None  # sentinel checked by chat_fn before generating
    log("Model loaded ✔")
except Exception as exc:  # noqa: BLE001
    # Deliberate broad catch at the top-level boundary: keep the Space alive
    # even if the model fails to load; chat_fn surfaces MODEL_ERR to users.
    MODEL_ERR, gen = f"Model load error: {exc}", None
    log(MODEL_ERR)
def clean(text: str) -> str:
    """Collapse every whitespace run in *text* to a single space.

    Returns "…" when the stripped/collapsed result would be empty, so the
    chat log never shows a blank message.
    """
    return re.sub(r"\s+", " ", text.strip()) or "…"


def trim(messages, max_turns=None):
    """Return *messages* capped to the system message + recent exchanges.

    Keeps the first (system) entry plus the last ``max_turns * 2`` messages
    (a user/assistant pair per turn). *max_turns* defaults to the
    module-level MAX_TURNS, preserving the original call signature.
    """
    limit = MAX_TURNS if max_turns is None else max_turns
    if len(messages) <= 1 + limit * 2:
        return messages
    return [messages[0]] + messages[-limit * 2:]
# --------------------------------------------------------------------------- | |
# Chat logic | |
# --------------------------------------------------------------------------- | |
def chat_fn(user_msg: str, history: list):
    """Gradio ChatInterface callback: produce one assistant reply.

    user_msg: raw text typed by the user.
    history: Gradio-managed list of {"role": ..., "content": ...} dicts.
    Returns the reply string, or a validation/error message.
    """
    log(f"User sent {len(user_msg)} chars")
    # Guarantee the system prompt is always the first message.
    if not history or history[0]["role"] != "system":
        history.insert(0, {"role": "system", "content": SYSTEM_MSG})
    if MODEL_ERR:
        # Model never loaded — surface the load error instead of crashing.
        return MODEL_ERR
    user_msg = clean(user_msg or "")
    if not user_msg:
        return "Please type something."
    if len(user_msg) > MAX_INPUT_CH:
        return f"Message too long (>{MAX_INPUT_CH} chars)."
    history.append({"role": "user", "content": user_msg})
    history = trim(history)
    # Flatten the chat into a plain-text prompt; the trailing "AI:" cues the
    # model to produce the next assistant turn.
    prompt_lines = [
        m["content"]
        if m["role"] == "system"
        else f'{"User" if m["role"] == "user" else "AI"}: {m["content"]}'
        for m in history
    ] + ["AI:"]
    prompt = "\n".join(prompt_lines)
    log(f"Prompt {len(prompt)} chars → generating")
    t0 = time.time()
    try:
        raw = gen(prompt)[0]["generated_text"]
        # BUG FIX: the pipeline echoes the prompt, and on any second turn the
        # prompt already contains earlier "AI:" lines — splitting on the FIRST
        # "AI:" returned a stale reply from history. Strip the echoed prompt
        # so only newly generated text remains (rsplit fallback just in case
        # the pipeline does not echo the prompt verbatim).
        completion = raw[len(prompt):] if raw.startswith(prompt) else raw.rsplit("AI:", 1)[-1]
        reply = clean(completion)
        # Cut off any hallucinated continuation of the dialogue.
        # (maxsplit passed by keyword: positional form is deprecated in 3.13.)
        reply = re.split(r"\b(?:User:|AI:)", reply, maxsplit=1)[0].strip()
        log(f"generate() {time.time()-t0:.2f}s, reply {len(reply)} chars")
    except Exception:
        # Boundary catch: log the traceback, return a friendly message.
        log("❌ Inference exception:\n" + traceback.format_exc())
        reply = "Sorry—backend crashed. Please try again later."
    return reply
# --------------------------------------------------------------------------- | |
# UI | |
# --------------------------------------------------------------------------- | |
# Build the chat UI and block on launch. type="messages" keeps history as
# role/content dicts, which is the shape chat_fn expects.
chat_ui = gr.ChatInterface(
    fn=chat_fn,
    chatbot=gr.Chatbot(
        height=480,
        type="messages",
        value=[{"role": "assistant", "content": WELCOME_MSG}],
    ),
    title="SchoolSpirit AI Chat",
    theme=gr.themes.Soft(primary_hue="blue"),
    type="messages",
)
chat_ui.launch()