import os

# HF_HOME must be set before transformers/gradio are imported; huggingface_hub
# resolves its cache directory at import time, so a later assignment has no effect.
os.environ["HF_HOME"] = "/data/.huggingface"

import re
import time
import datetime
import traceback
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers.utils import logging as hf_logging

# ---------------------------------------------------------------------------
# 0.  Paths & basic logging helper
# ---------------------------------------------------------------------------
LOG_FILE = "/data/requests.log"


def log(msg: str):
    """Print a timestamped line and append it to LOG_FILE (best effort)."""
    ts = datetime.datetime.utcnow().strftime("%H:%M:%S.%f")[:-3]
    line = f"[{ts}] {msg}"
    print(line, flush=True)
    try:
        with open(LOG_FILE, "a") as f:
            f.write(line + "\n")
    except OSError:
        # /data may be missing or read-only outside the deployment target;
        # the stdout log above still works.
        pass


# ---------------------------------------------------------------------------
# 1.  Configuration constants
# ---------------------------------------------------------------------------
MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
CONTEXT_TOKENS = 1800   # token budget for system prompt + trimmed history
MAX_NEW_TOKENS = 64     # cap on generated tokens per reply
TEMPERATURE = 0.6       # mild sampling randomness
MAX_INPUT_CH = 300      # reject user messages longer than this many characters

SYSTEM_MSG = (
    "You are **SchoolSpirit AI**, the official digital mascot of "
    "SchoolSpirit AI LLC.  Founded by Charles Norton in 2025, the company "
    "deploys on‑prem AI chat mascots, fine‑tunes language models, and ships "
    "turnkey GPU servers to K‑12 schools.\n\n"
    "RULES:\n"
    "• Friendly, concise (≤4 sentences unless prompted).\n"
    "• No personal data collection; no medical/legal/financial advice.\n"
    "• If uncertain, admit it & suggest human follow‑up.\n"
    "• Avoid profanity, politics, mature themes."
)
WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"

def strip(s: str) -> str:
    """Collapse runs of whitespace and trim leading/trailing spaces."""
    return re.sub(r"\s+", " ", s.strip())


# ---------------------------------------------------------------------------
# 2.  Load tokenizer + model  (GPU FP‑16 → CPU)
# ---------------------------------------------------------------------------
hf_logging.set_verbosity_error()
try:
    log("Loading tokenizer …")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

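    # Prefer the GPU with FP-16 weights (half the memory, faster inference);
    # otherwise load on CPU with low_cpu_mem_usage to reduce the loading peak.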
    if torch.cuda.is_available():
        log("GPU detected → FP‑16")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, device_map="auto", torch_dtype=torch.float16
        )
    else:
        log("CPU fallback")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            device_map="cpu",
            torch_dtype="auto",
            low_cpu_mem_usage=True,
        )

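    # return_full_text=False makes the pipeline return only the newly
    # generated text, so the prompt never has to be stripped off the reply.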
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=TEMPERATURE,
        return_full_text=False,
    )
    MODEL_ERR = None
    log("Model loaded ✔")
except Exception as exc:
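    # Keep the app alive on load failure: chat_fn surfaces MODEL_ERR to users
    # instead of raising at import time.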
    MODEL_ERR = f"Model load error: {exc}"
    generator = None
    log(MODEL_ERR)


# ---------------------------------------------------------------------------
# 3.  Helper: build prompt under token budget
# ---------------------------------------------------------------------------
def build_prompt(raw_history: list[dict]) -> str:
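    """Build the plain-text prompt: system message, the most recent turns, and
    a trailing "AI:" cue.  The oldest pair of messages is dropped repeatedly
    until the encoded prompt fits within CONTEXT_TOKENS or only the last two
    messages remain."""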
    system_msg = [m for m in raw_history if m["role"] == "system"][0]
    convo = [m for m in raw_history if m["role"] != "system"]

    def render(msg):
        prefix = "User:" if msg["role"] == "user" else "AI:"
        return f"{prefix} {msg['content']}"

    while True:
        parts = [system_msg["content"]] + [render(m) for m in convo] + ["AI:"]
        token_len = len(tokenizer.encode("\n".join(parts), add_special_tokens=False))
        if token_len <= CONTEXT_TOKENS or len(convo) <= 2:
            break
        convo = convo[2:]
    return "\n".join([system_msg["content"]] + [render(m) for m in convo] + ["AI:"])


# ---------------------------------------------------------------------------
# 4.  Chat callback with immediate user echo & spinner
# ---------------------------------------------------------------------------
def chat_fn(user_msg: str, display_history: list, state: dict):
    """Generator callback: echoes the user message immediately, then fills in
    the model reply once generation completes."""
    user_msg = strip(user_msg or "")
    if not user_msg:
        yield display_history, state
        return

    if len(user_msg) > MAX_INPUT_CH:
        display_history.append((user_msg, f"Input >{MAX_INPUT_CH} chars."))
        yield display_history, state
        return

    if MODEL_ERR:
        display_history.append((user_msg, MODEL_ERR))
        yield display_history, state
        return

    # Echo the user message right away with an empty placeholder reply
    display_history.append((user_msg, ""))
    yield display_history, state

    # Update raw history and build the prompt under the token budget
    state["raw"].append({"role": "user", "content": user_msg})
    prompt = build_prompt(state["raw"])

    # Generate the bot reply
    try:
        start = time.time()
        out = generator(prompt)[0]["generated_text"].strip()
        # Truncate any hallucinated next "User:" turn
        if "User:" in out:
            out = out.split("User:", 1)[0].strip()
        reply = out
        log(f"Reply in {time.time()-start:.2f}s ({len(reply)} chars)")
    except Exception:
        log("❌ Inference error:\n" + traceback.format_exc())
        reply = "Apologies—an internal error occurred. Please try again."

    # Replace the placeholder with the actual reply
    display_history[-1] = (user_msg, reply)
    state["raw"].append({"role": "assistant", "content": reply})
    yield display_history, state


# ---------------------------------------------------------------------------
# 5.  Launch Gradio Blocks UI with spinner
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown("### SchoolSpirit AI Chat")

    chatbot = gr.Chatbot(value=[(None, WELCOME_MSG)], height=480, label="SchoolSpirit AI")
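    # The visible chatbot holds (user, bot) display tuples; the raw role/content
    # history used to build prompts lives separately in this state object.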
    state = gr.State({"raw": [
        {"role": "system", "content": SYSTEM_MSG},
        {"role": "assistant", "content": WELCOME_MSG},
    ]})

    with gr.Row():
        txt = gr.Textbox(placeholder="Type your question here…", show_label=False, scale=4, lines=1)
        send_btn = gr.Button("Send", variant="primary")

    # show_progress=True displays a spinner while waiting
    send_btn.click(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state], show_progress=True)
    txt.submit(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state], show_progress=True)

# Queueing is required for generator event handlers such as chat_fn.
demo.queue().launch()