import os, re, time, datetime, traceback

# Set HF_HOME before transformers / huggingface_hub are imported so the
# model cache lands on the persistent /data volume.
os.environ["HF_HOME"] = "/data/.huggingface"

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers.utils import logging as hf_logging

# -------------------------------------------------------------------
# 1. Logging helpers
# -------------------------------------------------------------------
LOG_FILE = "/data/requests.log"
def log(msg: str):
    ts = datetime.datetime.utcnow().strftime("%H:%M:%S.%f")[:-3]
    line = f"[{ts}] {msg}"
    print(line, flush=True)
    try:
        with open(LOG_FILE, "a") as f:
            f.write(line + "\n")
    except FileNotFoundError:
        pass  # /data volume not mounted (e.g. running outside the Space)

# -------------------------------------------------------------------
# 2. Configuration
# -------------------------------------------------------------------
MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
MAX_TURNS, MAX_TOKENS, MAX_INPUT_CH = 4, 64, 300

SYSTEM_MSG = (
    "You are **SchoolSpirit AI**, the digital mascot for SchoolSpirit AI LLC, "
    "founded by Charles Norton in 2025. The company installs on‑prem AI chat "
    "mascots, offers custom fine‑tuning, and ships turnkey GPU hardware to "
    "K‑12 schools.\n\n"
    "GUIDELINES:\n"
    "• Warm, encouraging tone for students, parents, staff.\n"
    "• Replies ≤ 4 sentences unless asked for detail.\n"
    "• If unsure/out‑of‑scope: say so and suggest human follow‑up.\n"
    "• No personal‑data collection or sensitive advice.\n"
    "• No profanity, politics, or mature themes."
)
WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"


def strip(s: str) -> str:
    return re.sub(r"\s+", " ", s.strip())
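# Example: strip("  Hello,\n   world!  ") -> "Hello, world!"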

# -------------------------------------------------------------------
# 3. Load model (GPU FP‑16 → CPU fallback)
# -------------------------------------------------------------------
hf_logging.set_verbosity_error()
try:
    log("Loading tokenizer …")
    tok = AutoTokenizer.from_pretrained(MODEL_ID)

    if torch.cuda.is_available():
        log("GPU detected → FP‑16")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, device_map="auto", torch_dtype=torch.float16
        )
    else:
        log("CPU fallback")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, device_map="cpu", torch_dtype="auto", low_cpu_mem_usage=True
        )

    gen = pipeline(
        "text-generation",
        model=model,
        tokenizer=tok,
        max_new_tokens=MAX_TOKENS,
        do_sample=True,
        temperature=0.6,
    )
    MODEL_ERR = None
    log("Model loaded ✔")
except Exception as exc:  # noqa: BLE001
    MODEL_ERR, gen = f"Model load error: {exc}", None
    log(MODEL_ERR)
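
# Optional sanity check (commented out): one short generation right after
# loading confirms the pipeline works before any user reaches the UI.
# if gen is not None:
#     log("Smoke test: " + gen("User: Hello!\nAI:", max_new_tokens=16)[0]["generated_text"])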

# -------------------------------------------------------------------
# 4. Chat callback
# -------------------------------------------------------------------
def chat_fn(user_msg: str, history: list[tuple[str, str]], state: dict):
    """
    history: list of (user, assistant) tuples (Gradio default)
    state  : dict carrying the system prompt + raw history for the model
    Returns the updated history (for the UI) and state (for the next round).
    """
    if MODEL_ERR:
        return history + [(user_msg, MODEL_ERR)], state

    user_msg = strip(user_msg or "")
    if not user_msg:
        return history + [(user_msg, "Please type something.")], state
    if len(user_msg) > MAX_INPUT_CH:
        warn = f"Message too long (>{MAX_INPUT_CH} chars)."
        return history + [(user_msg, warn)], state
    # ------------------------------------------------ Prompt assembly
    raw_hist = state.get("raw", [])
    raw_hist.append({"role": "user", "content": user_msg})

    # keep the system prompt + the last MAX_TURNS exchanges
    convo = [m for m in raw_hist if m["role"] != "system"][-MAX_TURNS * 2 :]
    raw_hist = [{"role": "system", "content": SYSTEM_MSG}] + convo

    prompt = "\n".join(
        [
            m["content"]
            if m["role"] == "system"
            else f'{"User" if m["role"] == "user" else "AI"}: {m["content"]}'
            for m in raw_hist
        ]
        + ["AI:"]
    )
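    # The assembled prompt is plain text, roughly:
    #   <SYSTEM_MSG>
    #   User: <turn 1>
    #   AI: <reply 1>
    #   ...
    #   User: <latest message>
    #   AI: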
    try:
        # return_full_text=False yields only the newly generated continuation,
        # so earlier turns in the prompt are never mistaken for the reply.
        raw = gen(prompt, return_full_text=False)[0]["generated_text"]
        reply = strip(raw)
        # Trim anything the model invents as a follow-up turn.
        reply = re.split(r"\b(?:User:|AI:)", reply, maxsplit=1)[0].strip()
    except Exception:
        log("❌ Inference error:\n" + traceback.format_exc())
        reply = "Sorry—backend crashed. Please try again later."

    # ------------------------------------------------ Update state + UI history
    raw_hist.append({"role": "assistant", "content": reply})
    state["raw"] = raw_hist
    history.append((user_msg, reply))
    return history, state

# -------------------------------------------------------------------
# 5. Launch
# -------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    chatbot = gr.Chatbot(
        value=[("", WELCOME_MSG)], height=480, label="SchoolSpirit AI"
    )
    state = gr.State({"raw": [{"role": "system", "content": SYSTEM_MSG}]})

    with gr.Row():
        txt = gr.Textbox(
            scale=4, placeholder="Type your question here...", show_label=False
        )
        send = gr.Button("Send", variant="primary")

    send.click(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])
    txt.submit(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])

demo.launch()
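
# Note: on a Space this file is run automatically; locally, `python app.py`
# serves the same UI on Gradio's default port (7860).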