Spaces:
Paused
Paused
File size: 6,442 Bytes
2e445c2 e8212fa 502a6b6 e1aeada 2966210 2e445c2 2966210 9843b35 e8212fa 56e0226 502a6b6 e8212fa 56e0226 e8212fa 2966210 2e445c2 2966210 e4f3d46 2966210 e4f3d46 e8212fa 2e445c2 2966210 2e445c2 8a16381 502a6b6 2966210 ef0a942 2966210 ef0a942 56e0226 2966210 2e445c2 2966210 502a6b6 2e445c2 2966210 e8212fa 502a6b6 2e445c2 502a6b6 2e445c2 502a6b6 56e0226 502a6b6 ef0a942 2966210 9007cad 2966210 e63a535 2e445c2 e4f3d46 9007cad e8212fa 2e445c2 2966210 e8212fa 56e0226 2966210 56e0226 2966210 2e445c2 56e0226 2966210 56e0226 2966210 2e445c2 56e0226 2966210 2e445c2 56e0226 2966210 2e445c2 56e0226 885a86a 56e0226 ef0a942 56e0226 94cf8c8 56e0226 e4f3d46 56e0226 94cf8c8 2966210 56e0226 2966210 94cf8c8 2966210 56e0226 2966210 94cf8c8 e4f3d46 2e445c2 94cf8c8 56e0226 94cf8c8 56e0226 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
import os
import re
import time
import datetime
import traceback
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers.utils import logging as hf_logging
# ---------------------------------------------------------------------------
# 0. Paths & basic logging helper
# ---------------------------------------------------------------------------
# NOTE(review): this assignment runs after `transformers` has already been
# imported at the top of the file — confirm the library still honours HF_HOME
# when it is set this late; otherwise it has to move ahead of that import.
os.environ["HF_HOME"] = "/data/.huggingface"
LOG_FILE = "/data/requests.log"


def log(msg: str) -> None:
    """Print *msg* with a UTC timestamp and append the same line to LOG_FILE.

    The line is always written to stdout (flushed immediately).  Writing to
    LOG_FILE is best-effort: any OS-level failure (missing directory, no
    permission, read-only or full disk) is ignored so that logging can never
    take the application down.

    Args:
        msg: Text to record; it is emitted as ``[HH:MM:SS.mmm] msg``.
    """
    # Timezone-aware replacement for the deprecated datetime.utcnow();
    # the [:-3] slice trims microseconds down to milliseconds.
    ts = datetime.datetime.now(datetime.timezone.utc).strftime("%H:%M:%S.%f")[:-3]
    line = f"[{ts}] {msg}"
    print(line, flush=True)
    try:
        # Explicit UTF-8: messages may contain non-ASCII symbols, which a
        # locale-dependent default encoding could refuse to encode.
        with open(LOG_FILE, "a", encoding="utf-8") as f:
            f.write(line + "\n")
    except OSError:
        # Deliberately swallowed: the line already reached stdout above.
        pass
# ---------------------------------------------------------------------------
# 1. Configuration constants
# ---------------------------------------------------------------------------
# Hugging Face model identifier handed to the tokenizer and model loaders.
MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
# Token budget for the rendered prompt; build_prompt drops the oldest turns
# while the prompt is longer than this.
CONTEXT_TOKENS = 1800
# Passed to the generation pipeline as max_new_tokens.
MAX_NEW_TOKENS = 64
# Sampling temperature passed to the generation pipeline (do_sample=True).
TEMPERATURE = 0.6
# Longest user message, in characters, that chat_fn accepts.
MAX_INPUT_CH = 300
# System prompt: stored as the "system" entry of the raw history and rendered
# as the first line block of every prompt.
SYSTEM_MSG = (
    "You are **SchoolSpirit AI**, the official digital mascot of "
    "SchoolSpirit AI LLC. Founded by Charles Norton in 2025, the company "
    "deploys on‑prem AI chat mascots, fine‑tunes language models, and ships "
    "turnkey GPU servers to K‑12 schools.\n\n"
    "RULES:\n"
    "• Friendly, concise (≤4 sentences unless prompted).\n"
    "• No personal data collection; no medical/legal/financial advice.\n"
    "• If uncertain, admit it & suggest human follow‑up.\n"
    "• Avoid profanity, politics, mature themes."
)
# Greeting used both as the initial chatbot bubble and as the first
# assistant turn of the raw history.
WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"
def strip(s: str) -> str:
    """Return *s* trimmed, with each run of whitespace collapsed to one space.

    Newlines and tabs count as whitespace, so multi-line input becomes a
    single line.
    """
    return re.sub(r"\s+", " ", s.strip())
# ---------------------------------------------------------------------------
# 2. Load tokenizer + model (GPU FP‑16 → CPU)
# ---------------------------------------------------------------------------
# Silence Hugging Face warnings/info; only errors are reported.
hf_logging.set_verbosity_error()
try:
    log("Loading tokenizer …")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    # Choose the loader options first, then load the model in one place.
    if torch.cuda.is_available():
        log("GPU detected → FP‑16")
        load_kwargs = {"device_map": "auto", "torch_dtype": torch.float16}
    else:
        log("CPU fallback")
        load_kwargs = {
            "device_map": "cpu",
            "torch_dtype": "auto",
            "low_cpu_mem_usage": True,
        }
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)

    # Sampling pipeline that returns only the newly generated text.
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=TEMPERATURE,
        return_full_text=False,
    )
    MODEL_ERR = None
    log("Model loaded ✔")
except Exception as err:
    # Keep the app importable: the failure is recorded in MODEL_ERR and the
    # chat callback reports it instead of generating.
    generator = None
    MODEL_ERR = f"Model load error: {err}"
    log(MODEL_ERR)
# ---------------------------------------------------------------------------
# 3. Helper: build prompt under token budget
# ---------------------------------------------------------------------------
def build_prompt(raw_history: list[dict], *, max_tokens=None, count_tokens=None) -> str:
    """Render the chat history as a plain-text prompt within a token budget.

    The prompt is the system message, then one ``User: …`` / ``AI: …`` line
    per turn, then a trailing ``AI:`` cue for the model to complete.  While
    the prompt exceeds the budget, the two oldest non-system turns are
    dropped; trimming stops once at most two turns remain, so the result may
    still exceed the budget in that case.

    Args:
        raw_history: Messages as dicts with ``"role"`` (``"system"``,
            ``"user"`` or anything else, rendered as the AI) and
            ``"content"`` keys.  Only the first system message is used; a
            history without one yields a prompt with no system preamble.
        max_tokens: Token budget (int); defaults to ``CONTEXT_TOKENS``.
        count_tokens: Callable mapping a prompt string to its token count;
            defaults to counting with the module-level ``tokenizer``
            (special tokens excluded).

    Returns:
        The rendered prompt string.
    """
    if max_tokens is None:
        max_tokens = CONTEXT_TOKENS
    if count_tokens is None:
        def count_tokens(text):
            return len(tokenizer.encode(text, add_special_tokens=False))

    # Zero-or-one element list, so a missing system message is not an error.
    system_lines = [m["content"] for m in raw_history if m["role"] == "system"][:1]
    convo = [m for m in raw_history if m["role"] != "system"]

    def render(msg):
        prefix = "User:" if msg["role"] == "user" else "AI:"
        return f"{prefix} {msg['content']}"

    while True:
        prompt = "\n".join(system_lines + [render(m) for m in convo] + ["AI:"])
        if count_tokens(prompt) <= max_tokens or len(convo) <= 2:
            return prompt
        # Drop two turns at a time so the remaining turns keep alternating.
        convo = convo[2:]
# ---------------------------------------------------------------------------
# 4. Chat callback: validate input, generate a reply, update both histories
# ---------------------------------------------------------------------------
def chat_fn(user_msg: str, display_history: list, state: dict) -> tuple[list, dict]:
    """Handle one chat turn and return the updated ``(display_history, state)``.

    Args:
        user_msg: Raw textbox content; whitespace is normalised with
            ``strip`` and an empty result is ignored.
        display_history: List of ``(user, bot)`` tuples shown in the chatbot;
            it is extended in place.
        state: Session dict whose ``"raw"`` list holds the role/content
            messages used to build prompts; it is extended in place.

    Returns:
        ``(display_history, state)``.  Over-long input and a failed model
        load are reported in the display history only; they never enter the
        prompt history.  Any failure while building the prompt or generating
        is logged and answered with a fixed apology.
    """
    user_msg = strip(user_msg or "")
    if not user_msg:
        return display_history, state

    # Defensive: tolerate a missing history list instead of failing on append.
    if display_history is None:
        display_history = []

    if len(user_msg) > MAX_INPUT_CH:
        display_history.append((user_msg, f"Input >{MAX_INPUT_CH} chars."))
        return display_history, state
    if MODEL_ERR:
        display_history.append((user_msg, MODEL_ERR))
        return display_history, state

    state["raw"].append({"role": "user", "content": user_msg})
    try:
        # Prompt construction is inside the try so that a tokenizer failure
        # takes the same apology path as a generation failure, rather than
        # escaping with the user turn already recorded in state.
        prompt = build_prompt(state["raw"])
        start = time.time()
        out = generator(prompt)[0]["generated_text"].strip()
        # Cut off anything after a hallucinated next "User:" turn.
        reply = out.split("User:", 1)[0].strip()
        elapsed = time.time() - start
        log(f"Reply in {elapsed:.2f}s ({len(reply)} chars)")
    except Exception:
        log("❌ Inference error:\n" + traceback.format_exc())
        reply = "Apologies—an internal error occurred. Please try again."

    display_history.append((user_msg, reply))
    # Always record an assistant turn (even the apology) so user/assistant
    # turns stay paired for build_prompt's two-at-a-time trimming.
    state["raw"].append({"role": "assistant", "content": reply})
    return display_history, state
# ---------------------------------------------------------------------------
# 5. Launch Gradio Blocks UI with spinner
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown("### SchoolSpirit AI Chat")
    chatbot = gr.Chatbot(
        value=[("", WELCOME_MSG)],
        height=480,
        label="SchoolSpirit AI",
    )
    # Per-session prompt history, seeded with the system prompt and greeting.
    state = gr.State(
        {
            "raw": [
                {"role": "system", "content": SYSTEM_MSG},
                {"role": "assistant", "content": WELCOME_MSG},
            ]
        }
    )
    with gr.Row():
        txt = gr.Textbox(
            placeholder="Type your question here…",
            show_label=False,
            scale=4,
            lines=1,
        )
        send_btn = gr.Button("Send", variant="primary")

    # The Send button and the Enter key trigger the same callback with the
    # same wiring; show_progress=True displays a spinner while waiting.
    wiring = {
        "inputs": [txt, chatbot, state],
        "outputs": [chatbot, state],
        "show_progress": True,
    }
    for register in (send_btn.click, txt.submit):
        register(chat_fn, **wiring)

demo.launch()
|