import os

# HF_HOME must be set before transformers/huggingface_hub are imported,
# otherwise the default cache path has already been resolved.
os.environ["HF_HOME"] = "/data/.huggingface"

import re
import time
import datetime
import traceback

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers.utils import logging as hf_logging

# ---------------------------------------------------------------------------
# 0. Paths & basic logging helper
# ---------------------------------------------------------------------------
LOG_FILE = "/data/requests.log"


def log(msg: str):
    ts = datetime.datetime.now(datetime.timezone.utc).strftime("%H:%M:%S.%f")[:-3]
    line = f"[{ts}] {msg}"
    print(line, flush=True)
    try:
        with open(LOG_FILE, "a") as f:
            f.write(line + "\n")
    except Exception:
        pass  # tolerate logging failures
# ---------------------------------------------------------------------------
# 1. Configuration constants
# ---------------------------------------------------------------------------
MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
CONTEXT_TOKENS = 1800
MAX_NEW_TOKENS = 64
TEMPERATURE = 0.6
MAX_INPUT_CH = 300

SYSTEM_MSG = (
    "You are **SchoolSpirit AI**, the official digital mascot of "
    "SchoolSpirit AI LLC. Founded by Charles Norton in 2025, the company "
    "deploys on‑prem AI chat mascots, fine‑tunes language models, and ships "
    "turnkey GPU servers to K‑12 schools.\n\n"
    "RULES:\n"
    "• Friendly, concise (≤4 sentences unless prompted).\n"
    "• No personal data collection; no medical/legal/financial advice.\n"
    "• If uncertain, admit it & suggest human follow‑up.\n"
    "• Avoid profanity, politics, mature themes."
)

WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"

strip = lambda s: re.sub(r"\s+", " ", s.strip())
# ---------------------------------------------------------------------------
# 2. Load tokenizer + model (GPU FP‑16 → CPU fallback)
# ---------------------------------------------------------------------------
hf_logging.set_verbosity_error()
try:
    log("Loading tokenizer …")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    if torch.cuda.is_available():
        log("GPU detected → FP‑16")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, device_map="auto", torch_dtype=torch.float16
        )
    else:
        log("CPU fallback")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, device_map="cpu", torch_dtype="auto", low_cpu_mem_usage=True
        )

    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=TEMPERATURE,
        return_full_text=False,
        # NOTE: the text-generation pipeline has no `streaming` kwarg; it
        # returns the finished completion as a single chunk. See the
        # TextIteratorStreamer sketch below for true token streaming.
    )
    MODEL_ERR = None
    log("Model loaded ✔")
except Exception as exc:
    MODEL_ERR = f"Model load error: {exc}"
    generator = None
    log(MODEL_ERR)
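
# Optional sketch: true token-by-token streaming. The pipeline call above
# yields the whole completion at once; for real incremental output,
# transformers offers TextIteratorStreamer: generate() runs in a background
# thread while the streamer yields decoded text pieces as they arrive. This
# helper is illustrative and not wired into chat_fn by default; chat_fn would
# have to consume plain strings instead of {"generated_text": ...} dicts.
def stream_generate(prompt: str):
    from threading import Thread

    from transformers import TextIteratorStreamer

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    Thread(
        target=model.generate,
        kwargs=dict(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            streamer=streamer,
        ),
        daemon=True,
    ).start()
    yield from streamer  # each item is a decoded chunk of newly generated text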
# ---------------------------------------------------------------------------
# 3. Helper: build prompt under token budget (with fallback on error)
# ---------------------------------------------------------------------------
def build_prompt(raw_history: list[dict]) -> str:
    try:
        def render(msg):
            if msg["role"] == "system":
                return msg["content"]
            prefix = "User:" if msg["role"] == "user" else "AI:"
            return f"{prefix} {msg['content']}"

        system_msg = next(m for m in raw_history if m["role"] == "system")
        convo = [m for m in raw_history if m["role"] != "system"]

        # Drop the oldest user/assistant pair until the prompt fits the budget.
        while True:
            parts = [system_msg["content"]] + [render(m) for m in convo] + ["AI:"]
            token_len = len(tokenizer.encode("\n".join(parts), add_special_tokens=False))
            if token_len <= CONTEXT_TOKENS or len(convo) <= 2:
                break
            convo = convo[2:]
        return "\n".join(parts)
    except Exception:
        log("Error building prompt:\n" + traceback.format_exc())
        # Fallback: system message plus the last two turns only.
        sys_text = next((m["content"] for m in raw_history if m["role"] == "system"), "")
        tail = [m for m in raw_history if m["role"] != "system"][-2:]
        fallback = sys_text + "\n" + "\n".join(
            f"{'User:' if m['role'] == 'user' else 'AI:'} {m['content']}" for m in tail
        ) + "\nAI:"
        return fallback
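
# Alternative sketch (not used by chat_fn): instruct checkpoints like Granite
# normally ship a chat template, so the same history could be rendered by the
# tokenizer itself instead of the hand-rolled "User:/AI:" framing above. This
# assumes the template is present on the checkpoint; if it is not, we fall
# back to build_prompt.
def build_prompt_via_template(raw_history: list[dict]) -> str:
    try:
        return tokenizer.apply_chat_template(
            raw_history, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        return build_prompt(raw_history)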
# ---------------------------------------------------------------------------
# 4. Chat callback (generator yields partial updates; robust error handling)
# ---------------------------------------------------------------------------
def chat_fn(user_msg: str, display_history: list, state: dict):
    # display_history uses the Chatbot "messages" format: {"role", "content"}.
    user_msg = strip(user_msg or "")
    if not user_msg:
        yield display_history, state
        return

    if len(user_msg) > MAX_INPUT_CH:
        display_history += [
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": f"Input >{MAX_INPUT_CH} chars."},
        ]
        yield display_history, state
        return

    if MODEL_ERR:
        display_history += [
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": MODEL_ERR},
        ]
        yield display_history, state
        return

    try:
        # Record the user turn, plus an empty assistant turn to fill in.
        state["raw"].append({"role": "user", "content": user_msg})
        display_history += [
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": ""},
        ]
        prompt = build_prompt(state["raw"])
        start = time.time()
        partial = ""

        # The pipeline returns the completion as a single chunk; the loop is
        # kept so a genuinely streaming backend can be dropped in unchanged.
        for chunk in generator(prompt):
            try:
                new_text = strip(chunk.get("generated_text", ""))
                # Trim anything past a hallucinated next user turn.
                if "User:" in new_text:
                    new_text = new_text.split("User:", 1)[0].strip()
                partial += new_text
                display_history[-1]["content"] = partial
                yield display_history, state
            except Exception:
                log("Malformed chunk:\n" + traceback.format_exc())
                continue

        # Finalize.
        full_reply = display_history[-1]["content"]
        state["raw"].append({"role": "assistant", "content": full_reply})
        log(f"Reply in {time.time() - start:.2f}s ({len(full_reply)} chars)")
    except Exception:
        log("Unexpected chat_fn error:\n" + traceback.format_exc())
        err = "Apologies, an internal error occurred. Please try again."
        if display_history and display_history[-1]["role"] == "assistant":
            display_history[-1]["content"] = err
        else:
            display_history.append({"role": "assistant", "content": err})
        state["raw"].append({"role": "assistant", "content": err})
    yield display_history, state
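
# Quick smoke test (illustrative only, not part of the app): drives chat_fn
# directly from Python, outside Gradio, to exercise prompt budgeting and the
# error paths. `_smoke_test` is a hypothetical helper, not called anywhere.
def _smoke_test(msg: str = "What does SchoolSpirit AI do?") -> None:
    st = {"raw": [{"role": "system", "content": SYSTEM_MSG}]}
    history: list = []
    for history, st in chat_fn(msg, history, st):
        pass  # drain the generator; the final yield holds the full reply
    log(f"Smoke test reply: {history[-1]['content']!r}")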
# ---------------------------------------------------------------------------
# 5. Launch Gradio Blocks UI
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown("### SchoolSpirit AI Chat")
    chatbot = gr.Chatbot(
        value=[{"role": "assistant", "content": WELCOME_MSG}],
        height=480,
        label="SchoolSpirit AI",
        type="messages",  # openai-style {"role", "content"} dicts
    )
    state = gr.State(
        {
            "raw": [
                {"role": "system", "content": SYSTEM_MSG},
                {"role": "assistant", "content": WELCOME_MSG},
            ]
        }
    )
    with gr.Row():
        txt = gr.Textbox(
            placeholder="Type your question here…",
            show_label=False,
            scale=4,
            lines=1,
        )
        send_btn = gr.Button("Send", variant="primary")

    # Event listeners take no `streaming` kwarg; Gradio streams automatically
    # because chat_fn is a generator that yields partial updates.
    send_btn.click(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])
    txt.submit(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])
if __name__ == "__main__":
    try:
        demo.queue()  # older Gradio releases need an explicit queue for generator callbacks
        demo.launch()
    except Exception:
        log("UI launch error:\n" + traceback.format_exc())