import os

# HF_HOME must be set before transformers/huggingface_hub are imported,
# otherwise their default cache paths are already fixed at import time.
os.environ["HF_HOME"] = "/data/.huggingface"  # /data is the Space's persistent-storage mount

import re
import time
import datetime
import traceback

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers.utils import logging as hf_logging

# ---------------------------------------------------------------------------
# 0. Paths & basic logging helper
# ---------------------------------------------------------------------------
LOG_FILE = "/data/requests.log"


def log(msg: str):
    """Print a timestamped line and best-effort append it to LOG_FILE."""
    ts = datetime.datetime.now(datetime.timezone.utc).strftime("%H:%M:%S.%f")[:-3]
    line = f"[{ts}] {msg}"
    print(line, flush=True)
    try:
        with open(LOG_FILE, "a") as f:
            f.write(line + "\n")
    except OSError:
        # /data may be absent (no persistent storage attached); console output still works.
        pass

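# Example output (printed, and appended to /data/requests.log when available):
#   [14:02:07.123] Loading tokenizer …
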
# ---------------------------------------------------------------------------
# 1. Configuration constants
# ---------------------------------------------------------------------------
MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
CONTEXT_TOKENS = 1800  # token budget for the rendered prompt
MAX_NEW_TOKENS = 64
TEMPERATURE = 0.6
MAX_INPUT_CH = 300  # max characters accepted per user message

SYSTEM_MSG = (
    "You are **SchoolSpirit AI**, the official digital mascot of "
    "SchoolSpirit AI LLC. Founded by Charles Norton in 2025, the company "
    "deploys on-prem AI chat mascots, fine-tunes language models, and ships "
    "turnkey GPU servers to K-12 schools.\n\n"
    "RULES:\n"
    "• Friendly, concise (≤4 sentences unless prompted).\n"
    "• No personal data collection; no medical/legal/financial advice.\n"
    "• If uncertain, admit it & suggest human follow-up.\n"
    "• Avoid profanity, politics, mature themes."
)
WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"


def strip(s: str) -> str:
    """Collapse whitespace runs to single spaces and trim both ends."""
    return re.sub(r"\s+", " ", s.strip())

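# e.g. strip("  Hello,\n   world!  ") -> "Hello, world!"
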
# ---------------------------------------------------------------------------
# 2. Load tokenizer + model (GPU FP-16, falling back to CPU)
# ---------------------------------------------------------------------------
hf_logging.set_verbosity_error()
try:
    log("Loading tokenizer …")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    if torch.cuda.is_available():
        log("GPU detected → FP-16")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, device_map="auto", torch_dtype=torch.float16
        )
    else:
        log("CPU fallback")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            device_map="cpu",
            torch_dtype="auto",
            low_cpu_mem_usage=True,
        )

    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=TEMPERATURE,
        return_full_text=False,  # return only the newly generated text
    )
    MODEL_ERR = None
    log("Model loaded ✔")
except Exception as exc:
    # Keep the app alive; the error string is surfaced in the chat UI instead.
    MODEL_ERR = f"Model load error: {exc}"
    generator = None
    log(MODEL_ERR)

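# Note: passing device_map to from_pretrained relies on the `accelerate`
# package being installed; if it is absent, loading raises here and the
# except branch above reports the failure through MODEL_ERR.
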
# ---------------------------------------------------------------------------
# 3. Helper: build prompt under token budget
# ---------------------------------------------------------------------------
def build_prompt(raw_history: list[dict]) -> str:
    """Render system + conversation as a flat prompt, dropping the oldest
    user/assistant pair until the whole thing fits in CONTEXT_TOKENS."""
    system_msg = next(m for m in raw_history if m["role"] == "system")
    convo = [m for m in raw_history if m["role"] != "system"]

    def render(msg):
        prefix = "User:" if msg["role"] == "user" else "AI:"
        return f"{prefix} {msg['content']}"

    while True:
        parts = [system_msg["content"]] + [render(m) for m in convo] + ["AI:"]
        token_len = len(tokenizer.encode("\n".join(parts), add_special_tokens=False))
        if token_len <= CONTEXT_TOKENS or len(convo) <= 2:
            break
        convo = convo[2:]  # drop the oldest user/assistant pair

    return "\n".join([system_msg["content"]] + [render(m) for m in convo] + ["AI:"])

# ---------------------------------------------------------------------------
# 4. Chat callback with immediate user echo & spinner
# ---------------------------------------------------------------------------
def chat_fn(user_msg: str, display_history: list, state: dict):
    user_msg = strip(user_msg or "")
    if not user_msg:
        return display_history, state
    if len(user_msg) > MAX_INPUT_CH:
        display_history.append((user_msg, f"Input >{MAX_INPUT_CH} chars."))
        return display_history, state
    if MODEL_ERR:
        display_history.append((user_msg, MODEL_ERR))
        return display_history, state

    # Immediately append the user message with a placeholder reply
    display_history.append((user_msg, ""))

    # Update raw history and build the prompt under the token budget
    state["raw"].append({"role": "user", "content": user_msg})
    prompt = build_prompt(state["raw"])

    # Generate the bot reply
    try:
        start = time.time()
        out = generator(prompt)[0]["generated_text"].strip()
        # Truncate any hallucinated next "User:" turn
        if "User:" in out:
            out = out.split("User:", 1)[0].strip()
        reply = out
        log(f"Reply in {time.time() - start:.2f}s ({len(reply)} chars)")
    except Exception:
        log("❌ Inference error:\n" + traceback.format_exc())
        reply = "Apologies, an internal error occurred. Please try again."

    # Replace the placeholder with the actual reply
    display_history[-1] = (user_msg, reply)
    state["raw"].append({"role": "assistant", "content": reply})
    return display_history, state

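# `display_history` uses Gradio's tuple chat format (a list of (user, bot)
# pairs) for display, while state["raw"] keeps role-tagged dicts so that
# build_prompt can re-render the full conversation each turn.
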
# ---------------------------------------------------------------------------
# 5. Launch Gradio Blocks UI with spinner
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown("### SchoolSpirit AI Chat")
    chatbot = gr.Chatbot(value=[("", WELCOME_MSG)], height=480, label="SchoolSpirit AI")
    state = gr.State(
        {
            "raw": [
                {"role": "system", "content": SYSTEM_MSG},
                {"role": "assistant", "content": WELCOME_MSG},
            ]
        }
    )

    with gr.Row():
        txt = gr.Textbox(
            placeholder="Type your question here…",
            show_label=False,
            scale=4,
            lines=1,
        )
        send_btn = gr.Button("Send", variant="primary")

    # show_progress=True displays a spinner while the callback runs
    send_btn.click(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state], show_progress=True)
    txt.submit(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state], show_progress=True)

demo.launch()