phanerozoic committed on
Commit
dd7db97
·
verified ·
1 Parent(s): b9083a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -20
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import os, re, time, datetime, traceback
2
  import gradio as gr
3
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
4
  from transformers.utils import logging as hf_logging
@@ -25,15 +25,13 @@ def log(msg: str):
25
  # Configuration
26
  # ---------------------------------------------------------------------------
27
  MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
28
- MAX_TURNS = 4 # retain last N user/AI exchanges
29
- MAX_TOKENS = 64
30
- MAX_INPUT_CH = 300
31
 
32
  SYSTEM_MSG = (
33
- "You are **SchoolSpirit AI**, the digital mascot for SchoolSpirit AI LLC, "
34
  "founded by Charles Norton in 2025. The company installs on‑prem AI chat "
35
  "mascots, offers custom fine‑tuning of language models, and ships turnkey "
36
- "GPU hardware to K‑12 schools.\n\n"
37
  "GUIDELINES:\n"
38
  "• Use a warm, encouraging tone fit for students, parents, and staff.\n"
39
  "• Keep replies short—no more than four sentences unless asked.\n"
@@ -41,19 +39,32 @@ SYSTEM_MSG = (
41
  "• Never collect personal data or provide medical, legal, or financial advice.\n"
42
  "• No profanity, politics, or mature themes."
43
  )
44
-
45
  WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"
46
 
47
  # ---------------------------------------------------------------------------
48
- # Load model
49
  # ---------------------------------------------------------------------------
50
  hf_logging.set_verbosity_error()
51
  try:
52
- log("Loading model …")
53
  tok = AutoTokenizer.from_pretrained(MODEL_ID)
54
- model = AutoModelForCausalLM.from_pretrained(
55
- MODEL_ID, device_map="auto", torch_dtype="auto"
56
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  gen = pipeline(
58
  "text-generation",
59
  model=model,
@@ -77,15 +88,8 @@ trim = lambda m: m if len(m) <= 1 + MAX_TURNS * 2 else [m[0]] + m[-MAX_TURNS * 2
77
 
78
 
79
  def chat_fn(user_msg: str, history: list):
80
- """
81
- Gradio passes:
82
- user_msg : str
83
- history : list[dict] -> [{'role':'assistant'|'user','content':...}, ...]
84
- Return a string; ChatInterface will append it as assistant message.
85
- """
86
  log(f"User sent {len(user_msg)} chars")
87
 
88
- # Inject system message once
89
  if not history or history[0]["role"] != "system":
90
  history.insert(0, {"role": "system", "content": SYSTEM_MSG})
91
 
@@ -131,7 +135,7 @@ gr.ChatInterface(
131
  chatbot=gr.Chatbot(
132
  height=480,
133
  type="messages",
134
- value=[{"role": "assistant", "content": WELCOME_MSG}], # preloaded welcome
135
  ),
136
  title="SchoolSpirit AI Chat",
137
  theme=gr.themes.Soft(primary_hue="blue"),
 
1
+ import os, re, time, datetime, traceback, torch
2
  import gradio as gr
3
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
4
  from transformers.utils import logging as hf_logging
 
25
  # Configuration
26
  # ---------------------------------------------------------------------------
27
  MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
28
+ MAX_TURNS, MAX_TOKENS, MAX_INPUT_CH = 4, 64, 300
 
 
29
 
30
  SYSTEM_MSG = (
31
+ "You are **SchoolSpiritAI**, the digital mascot for SchoolSpiritAILLC, "
32
  "founded by Charles Norton in 2025. The company installs on‑prem AI chat "
33
  "mascots, offers custom fine‑tuning of language models, and ships turnkey "
34
+ "PC's with preinstalled language models to K‑12 schools.\n\n"
35
  "GUIDELINES:\n"
36
  "• Use a warm, encouraging tone fit for students, parents, and staff.\n"
37
  "• Keep replies short—no more than four sentences unless asked.\n"
 
39
  "• Never collect personal data or provide medical, legal, or financial advice.\n"
40
  "• No profanity, politics, or mature themes."
41
  )
 
42
  WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"
43
 
44
  # ---------------------------------------------------------------------------
45
+ # Load model (GPU FP‑16 if available → CPU fallback)
46
  # ---------------------------------------------------------------------------
47
  hf_logging.set_verbosity_error()
48
  try:
49
+ log("Loading tokenizer …")
50
  tok = AutoTokenizer.from_pretrained(MODEL_ID)
51
+
52
+ if torch.cuda.is_available():
53
+ log("GPU detected → loading model in FP‑16")
54
+ model = AutoModelForCausalLM.from_pretrained(
55
+ MODEL_ID,
56
+ device_map="auto", # put layers on available GPU(s)
57
+ torch_dtype=torch.float16,
58
+ )
59
+ else:
60
+ log("No GPU → loading model on CPU (FP‑32)")
61
+ model = AutoModelForCausalLM.from_pretrained(
62
+ MODEL_ID,
63
+ device_map="cpu",
64
+ torch_dtype="auto",
65
+ low_cpu_mem_usage=True,
66
+ )
67
+
68
  gen = pipeline(
69
  "text-generation",
70
  model=model,
 
88
 
89
 
90
  def chat_fn(user_msg: str, history: list):
 
 
 
 
 
 
91
  log(f"User sent {len(user_msg)} chars")
92
 
 
93
  if not history or history[0]["role"] != "system":
94
  history.insert(0, {"role": "system", "content": SYSTEM_MSG})
95
 
 
135
  chatbot=gr.Chatbot(
136
  height=480,
137
  type="messages",
138
+ value=[{"role": "assistant", "content": WELCOME_MSG}],
139
  ),
140
  title="SchoolSpirit AI Chat",
141
  theme=gr.themes.Soft(primary_hue="blue"),