import os

# Point the Hugging Face cache at persistent storage BEFORE transformers /
# huggingface_hub are imported; their default cache paths are resolved at import time.
os.environ["HF_HOME"] = "/data/.huggingface"

import re
import time
import datetime
import traceback

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers.utils import logging as hf_logging

# ---------------------------------------------------------------------------
# 0. Paths & basic logging helper
# ---------------------------------------------------------------------------
LOG_FILE = "/data/requests.log"
def log(msg: str):
    ts = datetime.datetime.utcnow().strftime("%H:%M:%S.%f")[:-3]
    line = f"[{ts}] {msg}"
    print(line, flush=True)
    try:
        with open(LOG_FILE, "a") as f:
            f.write(line + "\n")
    except Exception:
        pass  # tolerate logging failures

# ---------------------------------------------------------------------------
# 1. Configuration constants
# ---------------------------------------------------------------------------
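# CONTEXT_TOKENS caps the token length of the assembled prompt, MAX_NEW_TOKENS
# caps each generated reply, TEMPERATURE sets sampling randomness, and
# MAX_INPUT_CH rejects overly long user messages before they reach the model.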
MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
CONTEXT_TOKENS = 1800
MAX_NEW_TOKENS = 64
TEMPERATURE = 0.6
MAX_INPUT_CH = 300
SYSTEM_MSG = (
    "You are **SchoolSpirit AI**, the official digital mascot of "
    "SchoolSpirit AI LLC. Founded by Charles Norton in 2025, the company "
    "deploys on-prem AI chat mascots, fine-tunes language models, and ships "
    "turnkey GPU servers to K-12 schools.\n\n"
    "RULES:\n"
    "• Friendly, concise (≤4 sentences unless prompted).\n"
    "• No personal data collection; no medical/legal/financial advice.\n"
    "• If uncertain, admit it & suggest human follow-up.\n"
    "• Avoid profanity, politics, mature themes."
)
WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"
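# Collapse runs of whitespace and trim leading/trailing space.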
strip = lambda s: re.sub(r"\s+", " ", s.strip())
# ---------------------------------------------------------------------------
# 2. Load tokenizer + model (GPU FP‑16 → CPU)
# ---------------------------------------------------------------------------
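# Silence transformers logging except for errors.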
hf_logging.set_verbosity_error()
try:
    log("Loading tokenizer …")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    if torch.cuda.is_available():
        log("GPU detected → FP-16")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, device_map="auto", torch_dtype=torch.float16
        )
    else:
        log("CPU fallback")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, device_map="cpu", torch_dtype="auto", low_cpu_mem_usage=True
        )

    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=TEMPERATURE,
        return_full_text=False,  # return only the newly generated text, not the prompt
    )
    MODEL_ERR = None
    log("Model loaded ✔")
except Exception as exc:
    MODEL_ERR = f"Model load error: {exc}"
    generator = None
    log(MODEL_ERR)

# ---------------------------------------------------------------------------
# 3. Helper: build prompt under token budget (with fallback on error)
# ---------------------------------------------------------------------------
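# build_prompt flattens the raw message history into a plain-text prompt:
# the system message first, then alternating "User:" / "AI:" turns, dropping
# the oldest pair of turns until the encoded prompt fits within CONTEXT_TOKENS.
# On any error it falls back to the system message plus the last two turns.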
def build_prompt(raw_history: list[dict]) -> str:
    try:
        def render(msg):
            if msg["role"] == "system":
                return msg["content"]
            prefix = "User:" if msg["role"] == "user" else "AI:"
            return f"{prefix} {msg['content']}"

        system_msg = next(m for m in raw_history if m["role"] == "system")
        convo = [m for m in raw_history if m["role"] != "system"]

        while True:
            parts = [system_msg["content"]] + [render(m) for m in convo] + ["AI:"]
            token_len = len(tokenizer.encode("\n".join(parts), add_special_tokens=False))
            if token_len <= CONTEXT_TOKENS or len(convo) <= 2:
                break
            convo = convo[2:]
        return "\n".join(parts)
    except Exception:
        log("Error building prompt:\n" + traceback.format_exc())
        # Fallback: include system + last two messages
        sys_text = next((m["content"] for m in raw_history if m["role"] == "system"), "")
        tail = [m for m in raw_history if m["role"] != "system"][-2:]
        fallback = sys_text + "\n" + "\n".join(
            f"{'User:' if m['role'] == 'user' else 'AI:'} {m['content']}" for m in tail
        ) + "\nAI:"
        return fallback

# ---------------------------------------------------------------------------
# 4. Chat callback (streaming generator + robust error handling)
# ---------------------------------------------------------------------------
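# chat_fn is a Gradio generator callback: it validates the input, appends the
# user turn and an empty assistant placeholder to the visible history, runs the
# model, and yields (history, state) after each chunk so the UI updates live.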
def chat_fn(user_msg: str, display_history: list, state: dict):
    user_msg = strip(user_msg or "")
    if not user_msg:
        yield display_history, state
        return
    if len(user_msg) > MAX_INPUT_CH:
        display_history.append({"role": "user", "content": user_msg})
        display_history.append({"role": "assistant", "content": f"Input >{MAX_INPUT_CH} chars."})
        yield display_history, state
        return
    if MODEL_ERR:
        display_history.append({"role": "user", "content": user_msg})
        display_history.append({"role": "assistant", "content": MODEL_ERR})
        yield display_history, state
        return

    try:
        # record the user turn, plus an empty assistant placeholder to fill in
        state["raw"].append({"role": "user", "content": user_msg})
        display_history.append({"role": "user", "content": user_msg})
        display_history.append({"role": "assistant", "content": ""})
        prompt = build_prompt(state["raw"])
        start = time.time()
        partial = ""

        # the pipeline returns one dict per generated sequence; yield after each
        for chunk in generator(prompt):
            try:
                new_text = strip(chunk.get("generated_text", ""))
                # drop anything the model hallucinates as the next user turn
                if "User:" in new_text:
                    new_text = new_text.split("User:", 1)[0].strip()
                partial += new_text
                display_history[-1]["content"] = partial
                yield display_history, state
            except Exception:
                log("Malformed chunk:\n" + traceback.format_exc())
                continue

        # finalize
        full_reply = display_history[-1]["content"]
        state["raw"].append({"role": "assistant", "content": full_reply})
        log(f"Reply in {time.time() - start:.2f}s ({len(full_reply)} chars)")
    except Exception:
        log("Unexpected chat_fn error:\n" + traceback.format_exc())
        err = "Apologies, an internal error occurred. Please try again."
        display_history[-1] = {"role": "assistant", "content": err}
        state["raw"].append({"role": "assistant", "content": err})
        yield display_history, state

# ---------------------------------------------------------------------------
# 5. Launch Gradio Blocks UI
# ---------------------------------------------------------------------------
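# The Chatbot shows the conversation in messages format; gr.State carries the
# raw history (including the hidden system prompt) that build_prompt consumes.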
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown("### SchoolSpirit AI Chat")
    chatbot = gr.Chatbot(
        value=[{"role": "assistant", "content": WELCOME_MSG}],
        height=480,
        label="SchoolSpirit AI",
        type="messages",  # history is a list of {"role", "content"} dicts
    )
    state = gr.State(
        {
            "raw": [
                {"role": "system", "content": SYSTEM_MSG},
                {"role": "assistant", "content": WELCOME_MSG},
            ]
        }
    )
    with gr.Row():
        txt = gr.Textbox(
            placeholder="Type your question here…",
            show_label=False,
            scale=4,
            lines=1,
        )
        send_btn = gr.Button("Send", variant="primary")

    # chat_fn is a generator, so Gradio streams each yielded update automatically;
    # the event listeners take no extra streaming flag
    send_btn.click(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])
    txt.submit(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])


if __name__ == "__main__":
    try:
        demo.launch()
    except Exception:
        log("UI launch error:\n" + traceback.format_exc())