import random

import streamlit as st
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from unsloth import FastLanguageModel, is_bfloat16_supported
# Select the GPU when available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

st.title("🧠 Math LLM Demo")
st.text(f"Using device: {device}")
# === MODEL SELECTION ===
MODEL_OPTIONS = {
    "Vanilla GPT-2": "openai-community/gpt2",
    "GPT2-Small-CPT-CL-IFT": "jonathantiedchen/GPT2-Small-CPT-CL-IFT",
    "Mistral 7B+CPT+CL+IFT": "jonathantiedchen/MistralMath-CPT-IFT",
}
@st.cache_resource  # cache across reruns so the models are loaded only once
def load_models():
    models = {}
    for name, path in MODEL_OPTIONS.items():
        if "mistral" in name.lower():
            # Load the Mistral model in 4-bit via Unsloth so it fits in memory
            try:
                model, tokenizer = FastLanguageModel.from_pretrained(
                    model_name=path,
                    max_seq_length=2048,
                    dtype=torch.bfloat16 if is_bfloat16_supported() else torch.float16,
                    load_in_4bit=True,
                )
                if tokenizer.pad_token is None:
                    tokenizer.pad_token = tokenizer.eos_token
                FastLanguageModel.for_inference(model)
            except Exception as e:
                st.sidebar.error(f"⚠️ Failed to load Mistral model with Unsloth: {e}")
                continue
        else:
            # Plain Hugging Face loading path for the GPT-2 variants
            tokenizer = AutoTokenizer.from_pretrained(path)
            model = AutoModelForCausalLM.from_pretrained(path).to(device)
        model.eval()
        models[name] = {"tokenizer": tokenizer, "model": model}
    return models
st.sidebar.write("📥 Loading models…")
models = load_models()
st.sidebar.write(f"✅ Successfully loaded models: {', '.join(models.keys())}")

# Offer only the models that actually loaded; a failed Mistral load would
# otherwise cause a KeyError here
model_choice = st.selectbox("Choose a model:", list(models.keys()))
tokenizer = models[model_choice]["tokenizer"]
model = models[model_choice]["model"]
# === LOAD DATA ===
@st.cache_resource  # cache the dataset so it is not re-fetched on every rerun
def load_gsm8k_dataset():
    return load_dataset("openai/gsm8k", "main")["test"]

st.sidebar.write("📥 Loading GSM8K…")
gsm8k_data = load_gsm8k_dataset()
st.sidebar.write(f"📚 GSM8K loaded: {len(gsm8k_data)} samples")
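# Each GSM8K record is a dict with two string fields: "question" holds the
# word problem and "answer" holds a worked solution that ends in a
# "#### <number>" line with the final result.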
# === TABS ===
tab1, tab2 = st.tabs(["📝 Manual Prompting", "📊 GSM8K Evaluation"])
# === MANUAL GENERATION TAB ===
with tab1:
    prompt = st.text_area(
        "Enter your math prompt:",
        "Jasper has 5 apples and eats 2 of them. How many apples does he have left?",
    )
    if st.button("Generate Response", key="manual"):
        with st.spinner("🔄 Generating..."):
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
            output = model.generate(
                **inputs,
                max_new_tokens=100,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
            generated_text = tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
            # Decode only the newly generated tokens; slicing the decoded string
            # by len(prompt) can misalign when the tokenizer normalizes whitespace
            response_only = tokenizer.decode(
                output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
            ).strip()

        st.subheader("📝 Prompt")
        st.code(prompt)
        st.subheader("🧠 Model Output")
        st.code(generated_text)
        st.subheader("✂️ Response Only")
        st.success(response_only)
# === GSM8K TAB ===
with tab2:
    st.markdown("A random question from GSM8K will be shown. Click below to test the model.")
    if st.button("Run GSM8K Sample"):
        try:
            with st.spinner("🔄 Generating..."):
                sample = random.choice(gsm8k_data)
                question = sample["question"]
                gold_answer = sample["answer"]
                inputs = tokenizer(question, return_tensors="pt").to(model.device)
                output = model.generate(
                    **inputs,
                    max_new_tokens=150,
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )
                generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
                # Decode only the newly generated tokens (see note in tab 1)
                response_only = tokenizer.decode(
                    output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
                ).strip()

            st.subheader("📝 GSM8K Question")
            st.markdown(question)
            st.subheader("🧠 Model Output")
            st.markdown(generated_text)
            st.subheader("✂️ Response Only")
            st.success(response_only)
            st.subheader("✅ Gold Answer")
            st.info(gold_answer)
        except Exception as e:
            st.error(f"Error: {e}")