import os

# HF_HOME must be set before transformers/huggingface_hub are imported,
# otherwise the default cache location has already been resolved.
os.environ["HF_HOME"] = "/data/.huggingface"

import re
import time
import datetime
import traceback
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers.utils import logging as hf_logging

# ---------------------------------------------------------------------------
# 0.  Paths & basic logging helper
# ---------------------------------------------------------------------------
LOG_FILE = "/data/requests.log"

def log(msg: str):
    ts = datetime.datetime.utcnow().strftime("%H:%M:%S.%f")[:-3]
    line = f"[{ts}] {msg}"
    print(line, flush=True)
    try:
        with open(LOG_FILE, "a") as f:
            f.write(line + "\n")
    except Exception:
        pass  # tolerate logging failures

# ---------------------------------------------------------------------------
# 1.  Configuration constants
# ---------------------------------------------------------------------------
MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
CONTEXT_TOKENS = 1800
MAX_NEW_TOKENS = 64
TEMPERATURE = 0.6
MAX_INPUT_CH = 300

SYSTEM_MSG = (
    "You are **SchoolSpirit AI**, the official digital mascot of "
    "SchoolSpirit AI LLC.  Founded by Charles Norton in 2025, the company "
    "deploys on‑prem AI chat mascots, fine‑tunes language models, and ships "
    "turnkey GPU servers to K‑12 schools.\n\n"
    "RULES:\n"
    "• Friendly, concise (≤4 sentences unless prompted).\n"
    "• No personal data collection; no medical/legal/financial advice.\n"
    "• If uncertain, admit it & suggest human follow‑up.\n"
    "• Avoid profanity, politics, mature themes."
)
WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"

def strip(s: str) -> str:
    """Collapse runs of whitespace and trim the ends."""
    return re.sub(r"\s+", " ", s.strip())

# ---------------------------------------------------------------------------
# 2.  Load tokenizer + model  (GPU FP‑16, CPU fallback)
# ---------------------------------------------------------------------------
hf_logging.set_verbosity_error()
try:
    log("Loading tokenizer …")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    if torch.cuda.is_available():
        log("GPU detected → FP‑16")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, device_map="auto", torch_dtype=torch.float16
        )
    else:
        log("CPU fallback")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, device_map="cpu", torch_dtype="auto", low_cpu_mem_usage=True
        )

    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=TEMPERATURE,
        return_full_text=False,
        # NOTE: the transformers pipeline has no `streaming` kwarg; this call
        # returns the complete generation in one shot (see the sketch below
        # for true token-by-token streaming).
    )
    MODEL_ERR = None
    log("Model loaded ✔")
except Exception as exc:
    MODEL_ERR = f"Model load error: {exc}"
    generator = None
    log(MODEL_ERR)
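
# ---------------------------------------------------------------------------
# Optional: true token-by-token streaming (illustrative sketch, unused below)
# ---------------------------------------------------------------------------
# The pipeline above returns the whole completion at once.  If incremental
# streaming is wanted, one possible approach (an assumption, not part of the
# original app and not wired into chat_fn) is transformers'
# TextIteratorStreamer: run model.generate in a background thread and iterate
# over the streamer as decoded tokens arrive.  `stream_generate` is a
# hypothetical helper name.
def stream_generate(prompt: str):
    """Yield the growing reply text; reference sketch only."""
    from threading import Thread
    from transformers import TextIteratorStreamer

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    gen_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=TEMPERATURE,
    )
    Thread(target=model.generate, kwargs=gen_kwargs, daemon=True).start()
    partial = ""
    for token_text in streamer:
        partial += token_text
        yield partial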

# ---------------------------------------------------------------------------
# 3.  Helper: build prompt under token budget (with fallback on error)
# ---------------------------------------------------------------------------
def build_prompt(raw_history: list[dict]) -> str:
    try:
        def render(msg):
            if msg["role"] == "system":
                return msg["content"]
            prefix = "User:" if msg["role"] == "user" else "AI:"
            return f"{prefix} {msg['content']}"

        system_msg = next(m for m in raw_history if m["role"] == "system")
        convo = [m for m in raw_history if m["role"] != "system"]

        while True:
            parts = [system_msg["content"]] + [render(m) for m in convo] + ["AI:"]
            token_len = len(tokenizer.encode("\n".join(parts), add_special_tokens=False))
            if token_len <= CONTEXT_TOKENS or len(convo) <= 2:
                break
            convo = convo[2:]
        return "\n".join(parts)

    except Exception:
        log("Error building prompt:\n" + traceback.format_exc())
        # Fallback: include system + last two messages
        sys_text = next((m["content"] for m in raw_history if m["role"]=="system"), "")
        tail = [m for m in raw_history if m["role"]!="system"][-2:]
        fallback = sys_text + "\n" + "\n".join(
            f"{'User:' if m['role']=='user' else 'AI:'} {m['content']}" for m in tail
        ) + "\nAI:"
        return fallback
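
# A quick usage illustration (hypothetical input), showing the prompt layout
# that chat_fn relies on when it later splits the reply on "User:":
#
#   build_prompt([
#       {"role": "system", "content": SYSTEM_MSG},
#       {"role": "user", "content": "What does SchoolSpirit AI do?"},
#   ])
#   -> SYSTEM_MSG + "\nUser: What does SchoolSpirit AI do?\nAI:"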

# ---------------------------------------------------------------------------
# 4.  Chat callback (streaming generator + robust error handling)
# ---------------------------------------------------------------------------
def chat_fn(user_msg: str, display_history: list, state: dict):
    """Gradio generator callback: validate input, query the model, and yield
    the updated messages-format history plus state after each step."""
    user_msg = strip(user_msg or "")
    if not user_msg:
        yield display_history, state
        return

    if len(user_msg) > MAX_INPUT_CH:
        display_history += [
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": f"Input >{MAX_INPUT_CH} chars."},
        ]
        yield display_history, state
        return

    if MODEL_ERR:
        display_history += [
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": MODEL_ERR},
        ]
        yield display_history, state
        return

    try:
        # Record the user turn and an empty assistant placeholder to fill in.
        state["raw"].append({"role": "user", "content": user_msg})
        display_history.append({"role": "user", "content": user_msg})
        display_history.append({"role": "assistant", "content": ""})
        yield display_history, state

        prompt = build_prompt(state["raw"])
        start = time.time()
        partial = ""

        # The pipeline returns a list of result dicts; append each piece to
        # the assistant placeholder and yield the updated history.
        for chunk in generator(prompt):
            try:
                new_text = strip(chunk.get("generated_text", ""))
                # Trim anything generated past the assistant's own turn.
                if "User:" in new_text:
                    new_text = new_text.split("User:", 1)[0].strip()
                partial += new_text
                display_history[-1]["content"] = partial
                yield display_history, state
            except Exception:
                log("Malformed chunk:\n" + traceback.format_exc())
                continue

        # Finalize: persist the assistant turn in the raw conversation state.
        full_reply = display_history[-1]["content"]
        state["raw"].append({"role": "assistant", "content": full_reply})
        log(f"Reply in {time.time() - start:.2f}s ({len(full_reply)} chars)")

    except Exception:
        log("Unexpected chat_fn error:\n" + traceback.format_exc())
        err = "Apologies, an internal error occurred. Please try again."
        if display_history and display_history[-1]["role"] == "assistant":
            display_history[-1]["content"] = err
        else:
            display_history.append({"role": "assistant", "content": err})
        state["raw"].append({"role": "assistant", "content": err})
        yield display_history, state

# ---------------------------------------------------------------------------
# 5.  Launch Gradio Blocks UI
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown("### SchoolSpirit AI Chat")

    chatbot = gr.Chatbot(
        value=[{"role":"assistant","content":WELCOME_MSG}],
        height=480,
        label="SchoolSpirit AI",
        type="messages",  # use the new messages format
    )

    state = gr.State(
        {"raw": [
            {"role": "system", "content": SYSTEM_MSG},
            {"role": "assistant", "content": WELCOME_MSG},
        ]}
    )

    with gr.Row():
        txt = gr.Textbox(placeholder="Type your question here…", show_label=False, scale=4, lines=1)
        send_btn = gr.Button("Send", variant="primary")

    # chat_fn is a generator, so Gradio streams each yielded update on its own;
    # click()/submit() take no `streaming` keyword.
    send_btn.click(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])
    txt.submit(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])

if __name__ == "__main__":
    try:
        demo.launch()
    except Exception:
        log("UI launch error:\n" + traceback.format_exc())