# app.py  •  SchoolSpirit AI chatbot Space
# Granite‑3.3‑2B‑Instruct  |  Streaming + rate‑limit + hallucination guard
import os

# HF_HOME must be set before transformers is imported, otherwise huggingface_hub
# has already resolved its cache path and the /data setting is silently ignored.
os.environ["HF_HOME"] = "/data/.huggingface"

import re, time, datetime, threading, traceback, torch, gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from transformers.utils import logging as hf_logging

# ───────────────────────────────── Log helper ────────────────────────────────
LOG_FILE = "/data/requests.log"
def log(msg: str):
    # datetime.utcnow() is deprecated since Python 3.12; use an aware UTC time
    ts = datetime.datetime.now(datetime.timezone.utc).strftime("%H:%M:%S.%f")[:-3]
    line = f"[{ts}] {msg}"
    print(line, flush=True)
    try:
        with open(LOG_FILE, "a") as f:
            f.write(line + "\n")
    except OSError:
        pass  # /data may be absent or read-only outside the Space; stdout still gets the line

# ─────────────────────────────── Configuration ───────────────────────────────
MODEL_ID          = "ibm-granite/granite-3.3-2b-instruct"
CTX_TOKENS        = 1800
MAX_NEW_TOKENS    = 120
TEMP              = 0.6
MAX_INPUT_CH      = 300
RATE_N, RATE_SEC  = 5, 60      # 5 msgs / 60 s per IP

SYSTEM_MSG = (
    "You are **SchoolSpirit AI**, the friendly digital mascot of "
    "SchoolSpirit AI LLC, founded by Charles Norton in 2025. "
    "The company installs on‑prem AI chat mascots, fine‑tunes language models, "
    "and ships turnkey GPU servers to K‑12 schools.\n\n"
    "RULES:\n"
    "• Reply in ≤ 4 sentences unless asked for detail.\n"
    "• No personal‑data collection; no medical/legal/financial advice.\n"
    "• If uncertain, say so and suggest contacting a human.\n"
    "• If you can’t answer, politely direct the user to admin@schoolspiritai.com.\n"
    "• Keep language age‑appropriate; avoid profanity, politics, mature themes."
)
WELCOME = "Hi there! I’m SchoolSpirit AI. Ask me anything about our services!"

strip = lambda s: re.sub(r"\s+", " ", s.strip())

# ─────────────────────── Load tokenizer & model ──────────────────────────────
hf_logging.set_verbosity_error()
try:
    log("Loading tokenizer …")
    tok = AutoTokenizer.from_pretrained(MODEL_ID)

    if torch.cuda.is_available():
        log("GPU detected β†’ loading model in FP‑16")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            device_map="auto",
            torch_dtype=torch.float16,
        )
    else:
        log("No GPU β†’ loading model on CPU (this is slower)")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            device_map="cpu",
            torch_dtype="auto",
            low_cpu_mem_usage=True,
        )

    MODEL_ERR = None
    log("Model loaded βœ”")
except Exception as exc:
    MODEL_ERR = f"Model load error: {exc}"
    log("❌ " + MODEL_ERR + "\n" + traceback.format_exc())

# ────────────────────────── Per‑IP rate limiter ──────────────────────────────
VISITS: dict[str, list[float]] = {}
def allowed(ip: str) -> bool:
    now = time.time()
    VISITS[ip] = [t for t in VISITS.get(ip, []) if now - t < RATE_SEC]
    if len(VISITS[ip]) >= RATE_N:
        return False
    VISITS[ip].append(now)
    return True
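
# Illustration: with RATE_N=5 and RATE_SEC=60, a 6th message from one IP
# within a minute is rejected; it is admitted again once the oldest of the
# five stored timestamps falls outside the 60-second window.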

# ─────────────────────── Prompt builder (token budget) ───────────────────────
def build_prompt(raw: list[dict]) -> str:
    def render(m):
        if m["role"] == "system":
            return m["content"]
        prefix = "User:" if m["role"] == "user" else "AI:"
        return f"{prefix} {m['content']}"
    system, convo = raw[0], raw[1:]
    while True:
        parts = [system["content"]] + [render(m) for m in convo] + ["AI:"]
        if len(tok.encode("\n".join(parts), add_special_tokens=False)) <= CTX_TOKENS or len(convo) <= 2:
            return "\n".join(parts)
        convo = convo[2:]  # drop oldest user+assistant pair
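
# Shape of the rendered prompt (illustrative example, not real traffic):
#   <SYSTEM_MSG>
#   AI: Hi there! I’m SchoolSpirit AI. ...
#   User: How much does a mascot cost?
#   AI:
# The trailing "AI:" cues the model to produce the next assistant turn.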

# ───────────────────────── Streaming chat callback ───────────────────────────
def chat_fn(user_msg, chat_hist, state, request: gr.Request):
    ip = request.client.host if request else "anon"
    if not allowed(ip):
        chat_hist.append((user_msg, "Rate limit exceeded; please wait a minute."))
        yield chat_hist, state  # chat_fn is a generator: a plain `return x` would drop the update
        return

    user_msg = strip(user_msg or "")
    if not user_msg:
        yield chat_hist, state
        return
    if len(user_msg) > MAX_INPUT_CH:
        chat_hist.append((user_msg, f"Input is limited to {MAX_INPUT_CH} characters."))
        yield chat_hist, state
        return
    if MODEL_ERR:
        chat_hist.append((user_msg, MODEL_ERR))
        yield chat_hist, state
        return

    # append user turn & empty assistant slot
    chat_hist.append((user_msg, ""))
    state["raw"].append({"role": "user", "content": user_msg})

    prompt = build_prompt(state["raw"])
    input_ids = tok(prompt, return_tensors="pt").to(model.device).input_ids
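
    # model.generate blocks until decoding finishes, so it runs on a worker
    # thread while this generator drains the TextIteratorStreamer and re-yields
    # the growing reply; that is what streams tokens into the Chatbot.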

    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    threading.Thread(
        target=model.generate,
        kwargs=dict(
            input_ids=input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,   # without this, generate() decodes greedily and TEMP is ignored
            temperature=TEMP,
            pad_token_id=tok.eos_token_id,  # avoids the missing-pad-token warning
            streamer=streamer,
        ),
        daemon=True,  # don't keep the process alive if a generation is still running at exit
    ).start()

    partial = ""
    try:
        for token in streamer:
            partial += token
            # hallucination guard: stop if model starts new speaker tag
            if "User:" in partial or "\nAI:" in partial:
                partial = re.split(r"(?:\n?User:|\n?AI:)", partial)[0].strip()
                break
            chat_hist[-1] = (user_msg, partial)
            yield chat_hist, state
    except Exception as exc:
        log("❌ Stream error:\n" + traceback.format_exc())
        partial = "Apologiesβ€”internal error. Please try again."

    reply = strip(partial)
    chat_hist[-1] = (user_msg, reply)
    state["raw"].append({"role": "assistant", "content": reply})
    yield chat_hist, state  # final

# ─────────────────────────── Gradio Blocks UI ────────────────────────────────
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown("### SchoolSpiritΒ AI Chat")
    bot = gr.Chatbot(value=[("", WELCOME)], height=480, label="SchoolSpiritΒ AI")
    st  = gr.State({
        "raw": [
            {"role": "system", "content": SYSTEM_MSG},
            {"role": "assistant", "content": WELCOME},
        ]
    })
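    # st keeps the conversation as role/content dicts, separate from the
    # tuple-based Chatbot display, so build_prompt can re-render it verbatim.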
    with gr.Row():
        txt = gr.Textbox(placeholder="Type your question here…", show_label=False, lines=1, scale=4)
        btn = gr.Button("Send", variant="primary")
    # clear the textbox once the message is dispatched
    btn.click(chat_fn, inputs=[txt, bot, st], outputs=[bot, st]).then(lambda: "", None, txt)
    txt.submit(chat_fn, inputs=[txt, bot, st], outputs=[bot, st]).then(lambda: "", None, txt)

demo.queue().launch()  # explicit queue() so the streaming generator also works on older Gradio releases