# math_bot/app.py
import random

import streamlit as st
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from unsloth import FastLanguageModel, is_bfloat16_supported

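# Prefer the GPU when one is available. Only the plain Transformers checkpoints
# are moved to this device explicitly below; Unsloth handles placement for the
# 4-bit Mistral checkpoint itself.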
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
st.title("🧠 Math LLM Demo")
st.text(f"Using device: {device}")
# === MODEL SELECTION ===
MODEL_OPTIONS = {
    "Vanilla GPT-2": "openai-community/gpt2",
    "GPT2-Small-CPT-CL-IFT": "jonathantiedchen/GPT2-Small-CPT-CL-IFT",
    "Mistral 7B+CPT+CL+IFT": "jonathantiedchen/MistralMath-CPT-IFT",
}
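# Three checkpoints: the stock GPT-2 baseline, the author's fine-tuned GPT-2
# variant, and a Mistral-7B math model that is loaded in 4-bit through Unsloth
# below.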
@st.cache_resource
def load_models():
    models = {}
    for name, path in MODEL_OPTIONS.items():
        if "mistral" in name.lower():
            try:
                model, tokenizer = FastLanguageModel.from_pretrained(
                    model_name=path,
                    max_seq_length=2048,
                    dtype=torch.bfloat16 if is_bfloat16_supported() else torch.float16,
                    load_in_4bit=True,
                )
                if tokenizer.pad_token is None:
                    tokenizer.pad_token = tokenizer.eos_token
                FastLanguageModel.for_inference(model)
            except Exception as e:
                st.sidebar.error(f"⚠️ Failed to load Mistral model with Unsloth: {e}")
                continue
        else:
            tokenizer = AutoTokenizer.from_pretrained(path)
            model = AutoModelForCausalLM.from_pretrained(path).to(device)
            model.eval()
        models[name] = {"tokenizer": tokenizer, "model": model}
    return models
st.sidebar.write("📥 Loading models…")
models = load_models()
st.sidebar.write(f"✅ Successfully loaded models: {', '.join(models)}")
# Only offer models that actually loaded (the Mistral entry is skipped if the Unsloth load fails).
model_choice = st.selectbox("Choose a model:", list(models.keys()))
tokenizer = models[model_choice]["tokenizer"]
model = models[model_choice]["model"]
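# Both tabs below reuse this tokenizer/model pair. Thanks to @st.cache_resource,
# switching models in the selectbox re-runs the script but only swaps entries in
# the already-loaded dict; no weights are reloaded.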
# === LOAD DATA ===
@st.cache_resource
def load_gsm8k_dataset():
    return load_dataset("openai/gsm8k", "main")["test"]
st.sidebar.write("📥 Loading GSM8K…")
gsm8k_data = load_gsm8k_dataset()
st.sidebar.write("📊 GSM8K loaded:", len(gsm8k_data), "samples")
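
# Not wired into the UI: a minimal sketch of how the final numeric answer could
# be pulled from a GSM8K gold answer, which ends with a line of the form
# "#### <number>". The helper name is illustrative, not part of the original app.
def extract_final_answer(answer: str) -> str:
    return answer.split("####")[-1].strip()
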
# === TABS ===
tab1, tab2 = st.tabs(["🔓 Manual Prompting", "📊 GSM8K Evaluation"])
# === MANUAL GENERATION TAB ===
with tab1:
    prompt = st.text_area("Enter your math prompt:", "Jasper has 5 apples and eats 2 of them. How many apples does he have left?")
    if st.button("Generate Response", key="manual"):
        with st.spinner("🔄 Generating..."):
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
            output = model.generate(
                **inputs,
                max_new_tokens=100,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
            generated_text = tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
            # Drop the echoed prompt; assumes the decoded text starts with the prompt verbatim.
            response_only = generated_text[len(prompt):].strip()

        st.subheader("🔎 Prompt")
        st.code(prompt)
        st.subheader("🧠 Model Output")
        st.code(generated_text)
        st.subheader("✂️ Response Only")
        st.success(response_only)
# === GSM8K TAB ===
with tab2:
    st.markdown("A random question from GSM8K will be shown. Click below to test the model.")
    if st.button("Run GSM8K Sample"):
        try:
            with st.spinner("🔄 Generating..."):
                sample = random.choice(gsm8k_data)
                question = sample["question"]
                gold_answer = sample["answer"]
                inputs = tokenizer(question, return_tensors="pt").to(model.device)
                st.markdown("Generating output…")
                output = model.generate(
                    **inputs,
                    max_new_tokens=150,
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )
                generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
                # Drop the echoed question; assumes the decoded text starts with it verbatim.
                response_only = generated_text[len(question):].strip()

            st.subheader("📌 GSM8K Question")
            st.markdown(question)
            st.subheader("🔍 Model Output")
            st.markdown(generated_text)
            st.subheader("✂️ Response Only")
            st.success(response_only)
            st.subheader("✅ Gold Answer")
            st.info(gold_answer)
        except Exception as e:
            st.error(f"Error: {e}")
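
# Possible follow-up (sketch only, not wired in): score the sampled question by
# comparing final numbers inside the tab above, e.g.
#   gold = extract_final_answer(gold_answer)
#   predicted = response_only.split()[-1].strip(".") if response_only else ""
#   st.write("✅ Match" if predicted == gold else "❌ No match")
# The `predicted` heuristic is a rough assumption; GSM8K evaluations typically
# parse the last number in the generated text instead.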