# HF_HOME must point at the persistent /data volume *before* transformers
# is imported, otherwise the cache location has already been resolved
# against the default path.
import os

os.environ["HF_HOME"] = "/data/.huggingface"

import datetime
import re
import traceback

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers.utils import logging as hf_logging

# -------------------------------------------------------------------
# 1.  Logging helpers
# -------------------------------------------------------------------
LOG_FILE = "/data/requests.log"


def log(msg: str) -> None:
    """Print a timestamped line and best-effort append it to LOG_FILE."""
    ts = datetime.datetime.now(datetime.timezone.utc).strftime("%H:%M:%S.%f")[:-3]
    line = f"[{ts}] {msg}"
    print(line, flush=True)
    try:
        with open(LOG_FILE, "a") as f:
            f.write(line + "\n")
    except OSError:  # /data may be missing or read-only outside the Space
        pass
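
# Example output line (illustrative): "[14:02:31.417] Model loaded ✔"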


# -------------------------------------------------------------------
# 2.  Configuration
# -------------------------------------------------------------------
MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
# Kept exchanges per prompt, tokens generated per reply, input length cap
MAX_TURNS, MAX_TOKENS, MAX_INPUT_CH = 4, 64, 300

SYSTEM_MSG = (
    "You are **SchoolSpirit AI**, the digital mascot for SchoolSpirit AI LLC, "
    "founded by Charles Norton in 2025. The company installs on‑prem AI chat "
    "mascots, offers custom fine‑tuning, and ships turnkey GPU hardware to "
    "K‑12 schools.\n\n"
    "GUIDELINES:\n"
    "• Warm, encouraging tone for students, parents, staff.\n"
    "• Replies ≤ 4 sentences unless asked for detail.\n"
    "• If unsure/out‑of‑scope: say so and suggest human follow‑up.\n"
    "• No personal‑data collection or sensitive advice.\n"
    "• No profanity, politics, or mature themes."
)
WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"


def strip(s: str) -> str:
    """Collapse internal whitespace (spaces, newlines, tabs) to single spaces."""
    return re.sub(r"\s+", " ", s.strip())


# -------------------------------------------------------------------
# 3.  Load model (GPU FP‑16 → CPU fallback)
# -------------------------------------------------------------------
hf_logging.set_verbosity_error()
try:
    log("Loading tokenizer …")
    tok = AutoTokenizer.from_pretrained(MODEL_ID)

    if torch.cuda.is_available():
        log("GPU detected → FP‑16")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, device_map="auto", torch_dtype=torch.float16
        )
    else:
        log("CPU fallback")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, device_map="cpu", torch_dtype="auto", low_cpu_mem_usage=True
        )

    gen = pipeline(
        "text-generation",
        model=model,
        tokenizer=tok,
        max_new_tokens=MAX_TOKENS,
        do_sample=True,
        temperature=0.6,
    )
    MODEL_ERR = None
    log("Model loaded ✔")
except Exception as exc:  # noqa: BLE001
    MODEL_ERR, gen = f"Model load error: {exc}", None
    log(MODEL_ERR)
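
# Optional warm-up (a sketch, commented out): a one-token generation at
# startup absorbs first-request latency (CUDA init, kernel autotuning)
# before a real user hits the Space.
# if gen is not None:
#     gen("AI:", max_new_tokens=1)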


# -------------------------------------------------------------------
# 4.  Chat callback
# -------------------------------------------------------------------
def chat_fn(user_msg: str, history: list[tuple[str, str]], state: dict):
    """
    history: list of (user, assistant) tuples (Gradio default)
    state  : dict carrying system_prompt + raw_history for the model
    Returns updated history (for UI) and state (for next round)
    """
    if MODEL_ERR:
        return history + [(user_msg, MODEL_ERR)], state

    user_msg = strip(user_msg or "")
    if not user_msg:
        return history + [(None, "Please type something.")], state
    if len(user_msg) > MAX_INPUT_CH:
        warn = f"Message too long (>{MAX_INPUT_CH} chars)."
        return history + [(user_msg, warn)], state

    # ------------------------------------------------ Prompt assembly
    raw_hist = state.get("raw", [])
    raw_hist.append({"role": "user", "content": user_msg})
    # keep the system message plus only the last MAX_TURNS exchanges
    convo = [m for m in raw_hist if m["role"] != "system"][-MAX_TURNS * 2 :]
    raw_hist = [{"role": "system", "content": SYSTEM_MSG}] + convo

    def render(m: dict) -> str:
        if m["role"] == "system":
            return m["content"]
        speaker = "User" if m["role"] == "user" else "AI"
        return f"{speaker}: {m['content']}"

    prompt = "\n".join([render(m) for m in raw_hist] + ["AI:"])
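
    # NOTE (alternative, untested here): granite-3.3 instruct models ship a
    # chat template, so the tokenizer could build the prompt instead of the
    # hand-rolled "User:/AI:" markers above:
    #   prompt = tok.apply_chat_template(
    #       raw_hist, tokenize=False, add_generation_prompt=True
    #   )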

    try:
        raw = gen(prompt)[0]["generated_text"]
        # generated_text echoes the whole prompt, which already contains
        # earlier "AI:" turns, so split on the *last* marker to isolate the
        # new reply, then cut at any hallucinated follow-on turn.
        reply = strip(raw.rsplit("AI:", 1)[-1])
        reply = re.split(r"\b(?:User:|AI:)", reply, maxsplit=1)[0].strip()
    except Exception:
        log("❌ Inference error:\n" + traceback.format_exc())
        reply = "Sorry, the backend crashed. Please try again later."

    # ------------------------------------------------ Update state + UI history
    raw_hist.append({"role": "assistant", "content": reply})
    state["raw"] = raw_hist
    history.append((user_msg, reply))
    return history, state
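
# Quick REPL sanity check (illustrative; uses only names defined above):
#   >>> st = {"raw": [{"role": "system", "content": SYSTEM_MSG}]}
#   >>> ui, st = chat_fn("What does SchoolSpirit AI do?", [], st)
#   >>> ui[-1][1]          # latest assistant reply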


# -------------------------------------------------------------------
# 5.  Launch
# -------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    chatbot = gr.Chatbot(
        value=[(None, WELCOME_MSG)],  # None hides the empty user-side bubble
        height=480,
        label="SchoolSpirit AI",
    )
    state = gr.State({"raw": [{"role": "system", "content": SYSTEM_MSG}]})
    with gr.Row():
        txt = gr.Textbox(
            scale=4, placeholder="Type your question here...", show_label=False
        )
        send = gr.Button("Send", variant="primary")

    send.click(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state]).then(
        lambda: "", outputs=txt  # clear the box after each send
    )
    txt.submit(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state]).then(
        lambda: "", outputs=txt
    )

demo.launch()
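
# On a single shared GPU it may help to serialize concurrent requests, e.g.:
#   demo.queue(max_size=16).launch()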