"""
SchoolSpirit AI – minimal public chatbot Space
---------------------------------------------
• Loads Meta’s Llama‑3 3 B‑Instruct (fits HF CPU Space).
• Uses Hugging Face transformers + Gradio; no external deps.
• Keeps prompt short and trims history to fit the model context.
• Gracefully handles model‑load or inference errors.
"""

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers.utils import logging as hf_logging

hf_logging.set_verbosity_error()  # keep Space logs clean

MODEL_ID   = "meta-llama/Llama-3.2-3B-Instruct"
MAX_TURNS  = 6          # retain last N exchanges
MAX_TOKENS = 220        # response length
SYSTEM_MSG = (
    "You are SchoolSpirit AI, the friendly digital mascot for a company that "
    "provides on-prem AI chat mascots, fine-tuning services, and turnkey GPU "
    "hardware for schools. Keep answers concise, upbeat, and age-appropriate. "
    "If you don't know, say so and suggest contacting a human. Never request "
    "personal data."
)

# ---------------- Model ------------------------------------------------------
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model     = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto",          # auto-detect CPU / GPU if available
        torch_dtype="auto",
    )
    gen = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=MAX_TOKENS,
        do_sample=True,
        temperature=0.7,
        return_full_text=False,     # return only the completion, not the prompt
    )
except Exception as e:  # noqa: BLE001
    # Fatal startup failure – surface the error inside the chat UI
    def chat(user_msg, history):
        return f"Model load error: {e}"
else:
    # ---------------- Chat handler ------------------------------------------
    def chat(user_msg, history):
        """Gradio ChatInterface callback: (message, history) -> reply."""
        # Trim history to the last N exchanges to keep the prompt small
        if len(history) > MAX_TURNS:
            history = history[-MAX_TURNS:]

        # Build a plain-text prompt: system message, prior turns, new message
        prompt = SYSTEM_MSG + "\n"
        for u, a in history:
            prompt += f"User: {u}\nAI: {a}\n"
        prompt += f"User: {user_msg}\nAI:"

        # Generate; stop at a hallucinated "User:" turn if the model adds one
        try:
            completion = gen(prompt)[0]["generated_text"]
            reply = completion.split("User:", 1)[0].strip()
        except Exception as err:  # noqa: BLE001
            reply = "Sorry, something went wrong. Please try again later."
            hf_logging.get_logger("SchoolSpirit").error(str(err))

        return reply
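
    # Quick smoke test, bypassing the UI (a sketch; assumes the model loaded):
    #   print(chat("What does SchoolSpirit AI do?", []))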


# ---------------- UI ---------------------------------------------------------
gr.ChatInterface(
    chat,
    title="SchoolSpirit AI Chat",
    theme=gr.themes.Soft(primary_hue="blue"),  # light-blue chat UI
).launch()
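
# To run locally (a sketch; assumes access to the gated Llama-3.2 repo via
# `huggingface-cli login` or an HF_TOKEN environment variable):
#   pip install gradio transformers accelerate torch
#   python app.py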