import os

# HF_HOME must be set before transformers/huggingface_hub are imported,
# otherwise the cache location has already been resolved at import time
# and the model cache will not land in the Space's /data directory.
os.environ["HF_HOME"] = "/data/.huggingface"

import re, time, datetime, traceback, torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers.utils import logging as hf_logging

# -------------------------------------------------------------------
# 1. Logging helpers
# -------------------------------------------------------------------
LOG_FILE = "/data/requests.log"
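# Note: /data is assumed to be the Space's persistent-storage mount; if it is
# not available, log() below silently falls back to stdout-only logging.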
def log(msg: str):
    ts = datetime.datetime.utcnow().strftime("%H:%M:%S.%f")[:-3]
    line = f"[{ts}] {msg}"
    print(line, flush=True)
    try:
        with open(LOG_FILE, "a") as f:
            f.write(line + "\n")
    except FileNotFoundError:
        pass
# -------------------------------------------------------------------
# 2. Configuration
# -------------------------------------------------------------------
MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
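# MAX_TURNS: user/assistant exchange pairs kept as model context, MAX_TOKENS:
# max_new_tokens generated per reply, MAX_INPUT_CH: per-message length cap (chars).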
MAX_TURNS, MAX_TOKENS, MAX_INPUT_CH = 4, 64, 300
SYSTEM_MSG = (
    "You are **SchoolSpirit AI**, the digital mascot for SchoolSpirit AI LLC, "
    "founded by Charles Norton in 2025. The company installs on‑prem AI chat "
    "mascots, offers custom fine‑tuning, and ships turnkey GPU hardware to "
    "K‑12 schools.\n\n"
    "GUIDELINES:\n"
    "• Warm, encouraging tone for students, parents, staff.\n"
    "• Replies ≤ 4 sentences unless asked for detail.\n"
    "• If unsure/out‑of‑scope: say so and suggest human follow‑up.\n"
    "• No personal‑data collection or sensitive advice.\n"
    "• No profanity, politics, or mature themes."
)
WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"
def strip(s: str) -> str:
    return re.sub(r"\s+", " ", s.strip())
# -------------------------------------------------------------------
# 3. Load model (GPU FP‑16 → CPU fallback)
# -------------------------------------------------------------------
hf_logging.set_verbosity_error()
try:
    log("Loading tokenizer …")
    tok = AutoTokenizer.from_pretrained(MODEL_ID)

    if torch.cuda.is_available():
        log("GPU detected → FP‑16")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, device_map="auto", torch_dtype=torch.float16
        )
    else:
        log("CPU fallback")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, device_map="cpu", torch_dtype="auto", low_cpu_mem_usage=True
        )

    gen = pipeline(
        "text-generation",
        model=model,
        tokenizer=tok,
        max_new_tokens=MAX_TOKENS,
        do_sample=True,
        temperature=0.6,
    )
    MODEL_ERR = None
    log("Model loaded ✔")
except Exception as exc:  # noqa: BLE001
    MODEL_ERR, gen = f"Model load error: {exc}", None
    log(MODEL_ERR)
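# A failed load does not crash the Space: MODEL_ERR is returned to the user
# from chat_fn on every request instead.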
# -------------------------------------------------------------------
# 4. Chat callback
# -------------------------------------------------------------------
def chat_fn(user_msg: str, history: list[tuple[str, str]], state: dict):
    """
    history: list of (user, assistant) tuples (Gradio default)
    state  : dict carrying system_prompt + raw_history for the model
    Returns updated history (for UI) and state (for next round)
    """
    if MODEL_ERR:
        return history + [(user_msg, MODEL_ERR)], state

    user_msg = strip(user_msg or "")
    if not user_msg:
        return history + [(user_msg, "Please type something.")], state
    if len(user_msg) > MAX_INPUT_CH:
        warn = f"Message too long (>{MAX_INPUT_CH} chars)."
        return history + [(user_msg, warn)], state

    # ------------------------------------------------ Prompt assembly
    raw_hist = state.get("raw", [])
    raw_hist.append({"role": "user", "content": user_msg})
    # keep system + last N exchanges
    convo = [m for m in raw_hist if m["role"] != "system"][-MAX_TURNS * 2 :]
    raw_hist = [{"role": "system", "content": SYSTEM_MSG}] + convo
    prompt = "\n".join(
        [
            m["content"]
            if m["role"] == "system"
            else f'{"User" if m["role"] == "user" else "AI"}: {m["content"]}'
            for m in raw_hist
        ]
        + ["AI:"]
    )
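    # Illustrative prompt layout built above (system text, then alternating
    # "User:"/"AI:" lines, ending with a bare "AI:" for the model to continue):
    #   <SYSTEM_MSG>
    #   User: Do you offer on-prem installs?
    #   AI: ...
    #   User: And fine-tuning?
    #   AI: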
    try:
        raw = gen(prompt)[0]["generated_text"]
        # generated_text echoes the prompt, so keep only the newly generated
        # tail, then cut at any follow-on "User:"/"AI:" turn the model invents
        reply = strip(raw[len(prompt):])
        reply = re.split(r"\b(?:User:|AI:)", reply, maxsplit=1)[0].strip()
    except Exception:
        log("❌ Inference error:\n" + traceback.format_exc())
        reply = "Sorry, the backend crashed. Please try again later."
    # ------------------------------------------------ Update state + UI history
    raw_hist.append({"role": "assistant", "content": reply})
    state["raw"] = raw_hist
    history.append((user_msg, reply))
    return history, state
# -------------------------------------------------------------------
# 5. Launch
# -------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    chatbot = gr.Chatbot(
        value=[(None, WELCOME_MSG)],  # None → show the welcome as a bot-only bubble
        height=480,
        label="SchoolSpirit AI",
    )
    state = gr.State({"raw": [{"role": "system", "content": SYSTEM_MSG}]})
    with gr.Row():
        txt = gr.Textbox(
            scale=4, placeholder="Type your question here...", show_label=False
        )
        send = gr.Button("Send", variant="primary")
    send.click(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])
    txt.submit(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])

demo.launch()
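# On Hugging Face Spaces the Gradio server is exposed automatically, so no
# share/port arguments are typically needed here.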