File size: 4,262 Bytes
e8212fa
 
502a6b6
e1aeada
 
e8212fa
9843b35
e8212fa
 
 
502a6b6
 
e8212fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
502a6b6
e8212fa
ef0a942
61ca5d6
ef0a942
 
e8212fa
502a6b6
 
e8212fa
502a6b6
e8212fa
502a6b6
e8212fa
502a6b6
 
 
 
e8212fa
502a6b6
 
 
ef0a942
502a6b6
9007cad
 
 
e8212fa
502a6b6
9961fac
e8212fa
9007cad
e8212fa
 
 
 
 
 
61ca5d6
e8212fa
 
 
 
 
 
 
318dc96
502a6b6
9007cad
e8212fa
 
 
 
61ca5d6
e8212fa
 
 
 
 
 
 
 
 
885a86a
502a6b6
 
e8212fa
 
502a6b6
e8212fa
502a6b6
ef0a942
e8212fa
 
 
ef0a942
e8212fa
9961fac
 
e8212fa
 
 
 
 
 
 
 
9961fac
 
e8212fa
9961fac
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os, re, time, datetime, traceback, torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers.utils import logging as hf_logging

# ---------- Logging ---------------------------------------------------------
os.environ["HF_HOME"] = "/data/.huggingface"  # cache HF downloads on the persistent /data volume
LOG_FILE = "/data/requests.log"  # best-effort append-only log (see log())


def log(msg: str) -> None:
    """Print *msg* with a millisecond UTC timestamp and append it to LOG_FILE.

    File logging is best-effort: any OS-level failure (missing /data mount,
    permissions, read-only filesystem) is silently ignored so that logging
    can never take the app down.
    """
    # Timezone-aware replacement for the deprecated datetime.utcnow();
    # the formatted H:M:S.mmm output is identical.
    ts = datetime.datetime.now(datetime.timezone.utc).strftime("%H:%M:%S.%f")[:-3]
    line = f"[{ts}] {msg}"
    print(line, flush=True)
    try:
        with open(LOG_FILE, "a") as f:
            f.write(line + "\n")
    except OSError:
        # Broader than the original FileNotFoundError: also swallows
        # PermissionError / read-only-FS errors, which previously crashed.
        pass


# ---------- Config ----------------------------------------------------------
MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
# history window (turn pairs), max generated tokens, max input characters
MAX_TURNS, MAX_TOKENS, MAX_INPUT_CH = 6, 128, 400

SYSTEM_MSG = (
    "You are **SchoolSpirit AI**, the digital mascot for SchoolSpirit AI LLC, "
    "founded by Charles Norton in 2025. The company installs on‑prem AI chat "
    "mascots, offers custom fine‑tuning, and ships turnkey GPU hardware to schools.\n\n"
    "Guidelines:\n"
    "• Warm, concise answers (max 4 sentences).\n"
    "• No personal‑data collection or sensitive advice.\n"
    "• If unsure, say so and suggest a human follow‑up.\n"
    "• Avoid profanity, politics, or mature themes."
)
WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"


def strip(s: str) -> str:
    """Collapse every run of whitespace in *s* to a single space and trim ends.

    Replaces the original ``strip = lambda …`` assignment (PEP 8 E731);
    behavior is unchanged.
    """
    return re.sub(r"\s+", " ", s.strip())


# ---------- Load model (GPU FP‑16 → CPU fallback) ---------------------------
hf_logging.set_verbosity_error()  # silence transformers' INFO/WARN chatter
try:
    log("Loading tokenizer …")
    tok = AutoTokenizer.from_pretrained(MODEL_ID)

    if torch.cuda.is_available():
        log("GPU detected → FP‑16")
        # FP16 halves weight memory vs FP32; device_map="auto" lets
        # accelerate place layers on the available device(s).
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, device_map="auto", torch_dtype=torch.float16
        )
    else:
        log("CPU fallback")
        # low_cpu_mem_usage streams weights during load instead of
        # materialising a full extra copy; "auto" keeps the checkpoint dtype.
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, device_map="cpu", torch_dtype="auto", low_cpu_mem_usage=True
        )

    gen = pipeline(
        "text-generation",
        model=model,
        tokenizer=tok,
        max_new_tokens=MAX_TOKENS,
        do_sample=True,
        temperature=0.6,
        pad_token_id=tok.eos_token_id,  # explicit pad avoids a per-call warning
    )
    MODEL_ERR = None  # sentinel checked by chat_fn before generating
    log("Model loaded ✔")
except Exception as exc:  # noqa: BLE001
    # Top-level boundary catch: record the failure so the UI can report it
    # on every request instead of the app dying at startup.
    MODEL_ERR, gen = f"Model load error: {exc}", None
    log(MODEL_ERR)


# ---------- Chat callback ---------------------------------------------------
def chat_fn(user_msg: str, history: list[dict]):
    """Gradio chat callback: validate input, build a prompt, generate a reply.

    Parameters
    ----------
    user_msg : raw text from the input box (may be empty/None).
    history  : chat state as list[{'role': 'user'|'assistant', 'content': str}];
               returned with the new user/assistant turns appended.
    """
    if MODEL_ERR:  # model never loaded — surface the error instead of crashing
        return history + [{"role": "assistant", "content": MODEL_ERR}]

    user_msg = strip(user_msg or "")
    if not user_msg:
        return history + [{"role": "assistant", "content": "Please type something."}]
    if len(user_msg) > MAX_INPUT_CH:
        warn = f"Message too long (>{MAX_INPUT_CH} chars)."
        return history + [{"role": "assistant", "content": warn}]

    # Append user to history
    history.append({"role": "user", "content": user_msg})

    # Keep system + last N messages
    convo = [m for m in history if m["role"] != "system"][-MAX_TURNS * 2 :]
    prompt_parts = [SYSTEM_MSG] + [
        f"{'User' if m['role']=='user' else 'AI'}: {m['content']}" for m in convo
    ] + ["AI:"]
    prompt = "\n".join(prompt_parts)

    try:
        raw = gen(prompt)[0]["generated_text"]
        # BUG FIX: the pipeline echoes the whole prompt, which already contains
        # one "AI:" line per previous assistant turn. split("AI:", 1) grabbed
        # the text after the FIRST "AI:" (a stale past reply); rsplit takes the
        # text after the LAST "AI:", i.e. the fresh completion.
        reply = strip(raw.rsplit("AI:", 1)[-1])
        # Trim anything the model hallucinated past its own turn
        # (maxsplit must be keyword-only for re.split as of Python 3.13).
        reply = re.split(r"\b(?:User:|AI:)", reply, maxsplit=1)[0].strip()
    except Exception:
        log("❌ Inference error:\n" + traceback.format_exc())
        reply = "Sorry—backend crashed. Please try again later."

    history.append({"role": "assistant", "content": reply})
    return history


# ---------- Launch ----------------------------------------------------------
gr.ChatInterface(
    fn=chat_fn,
    chatbot=gr.Chatbot(
        height=480,
        type="messages",  # role/content dict format, matching chat_fn's history
        value=[
            {"role": "assistant", "content": WELCOME_MSG}
        ],  # ONE welcome bubble
    ),
    additional_inputs=None,
    title="SchoolSpirit AI Chat",
    theme=gr.themes.Soft(primary_hue="blue"),
    examples=None,
).launch()