import streamlit as st
import pandas as pd
import json
import os
from datetime import datetime
from utils import (
    load_model,
    get_hf_token,
    simulate_training,
    plot_training_metrics,
    load_finetuned_model,
    save_model
)
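# The helpers above come from the app's local utils module; their exact signatures
# are assumed from how they are used below (load_model returning a
# (tokenizer, model) pair, simulate_training yielding per-epoch metrics, etc.).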
st.title("🔥 Fine-tune the Gemma Model")
# -------------------------------
# Finetuning Option Selection
# -------------------------------
finetune_option = st.radio("Select Finetuning Option", ["Fine-tune from scratch", "Refinetune existing model"])
# -------------------------------
# Model Selection Logic
# -------------------------------
selected_model = None
saved_model_path = None
if finetune_option == "Fine-tune from scratch":
    # Display Hugging Face model list
    model_list = [
        "google/gemma-3-1b-pt",
        "google/gemma-3-1b-it",
        "google/gemma-3-4b-pt",
        "google/gemma-3-4b-it",
        "google/gemma-3-12b-pt",
        "google/gemma-3-12b-it",
        "google/gemma-3-27b-pt",
        "google/gemma-3-27b-it"
    ]
    selected_model = st.selectbox("🛠️ Select Gemma Model to Fine-tune", model_list)
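    # Naming note: the "-pt" checkpoints are the pretrained base models, while
    # the "-it" checkpoints are the instruction-tuned variants.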
elif finetune_option == "Refinetune existing model":
    # Dynamically list all saved models from the /models folder
    model_dir = "models"
    if os.path.exists(model_dir):
        saved_models = [f for f in os.listdir(model_dir) if f.endswith(".pt")]
    else:
        saved_models = []
    if saved_models:
        saved_model_path = st.selectbox("Select a saved model to re-finetune", saved_models)
        saved_model_path = os.path.join(model_dir, saved_model_path)
        st.success(f"✅ Selected model for refinement: `{saved_model_path}`")
    else:
        st.warning("⚠️ No saved models found! Switching to fine-tuning from scratch.")
        finetune_option = "Fine-tune from scratch"
# -------------------------------
# Dataset Selection
# -------------------------------
st.subheader("📂 Dataset Selection")
# Dataset source selection
dataset_option = st.radio("Choose dataset:", ["Upload New Dataset", "Use Existing Dataset (`train_data.csv`)"])
dataset_path = "train_data.csv"
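# Uploaded data is written to (or appended onto) this single CSV, so repeated
# uploads accumulate into one training set.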
if dataset_option == "Upload New Dataset":
    uploaded_file = st.file_uploader("📤 Upload Dataset (CSV or JSON)", type=["csv", "json"])
    if uploaded_file is not None:
        # Handle CSV or JSON upload
        if uploaded_file.name.endswith(".csv"):
            new_data = pd.read_csv(uploaded_file)
        elif uploaded_file.name.endswith(".json"):
            json_data = json.load(uploaded_file)
            new_data = pd.json_normalize(json_data)
        else:
            st.error("❌ Unsupported file format. Please upload CSV or JSON.")
            st.stop()
        # Append or create new dataset
        if os.path.exists(dataset_path):
            new_data.to_csv(dataset_path, mode='a', index=False, header=False)
            st.success(f"✅ Data appended to `{dataset_path}`!")
        else:
            new_data.to_csv(dataset_path, index=False)
            st.success(f"✅ Dataset saved as `{dataset_path}`!")
elif dataset_option == "Use Existing Dataset (`train_data.csv`)":
    if os.path.exists(dataset_path):
        st.success("✅ Using existing `train_data.csv` for fine-tuning.")
    else:
        st.error("❌ `train_data.csv` not found! Please upload a new dataset.")
        st.stop()
# -------------------------------
# Hyperparameters Configuration
# -------------------------------
learning_rate = st.number_input("📉 Learning Rate", value=1e-4, format="%.5f")
batch_size = st.number_input("🛠️ Batch Size", value=16, step=1)
epochs = st.number_input("⏱️ Epochs", value=3, step=1)
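# Note: only `epochs` is actually consumed by simulate_training below;
# learning_rate and batch_size are collected here but not yet wired into the
# (simulated) training loop.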
# -------------------------------
# Fine-tuning Execution
# -------------------------------
if st.button("🚀 Start Fine-tuning"):
    st.info("Fine-tuning process initiated...")
    # Retrieve Hugging Face Token
    hf_token = get_hf_token()
    # Model loading logic
    if finetune_option == "Refinetune existing model" and saved_model_path:
        # Load the base model first.
        # Note: the base architecture is hardcoded to google/gemma-3-1b-it, so the
        # saved checkpoint must have been produced from a compatible model.
        tokenizer, model = load_model("google/gemma-3-1b-it", hf_token)
        # Load the saved model checkpoint for re-finetuning
        model = load_finetuned_model(model, saved_model_path)
        if model:
            st.success(f"✅ Loaded saved model: `{saved_model_path}` for refinement!")
        else:
            st.error("❌ Failed to load the saved model. Aborting.")
            st.stop()
    else:
        # Fine-tune from scratch (load base model)
        if not selected_model:
            st.error("❌ Please select a model to fine-tune.")
            st.stop()
        tokenizer, model = load_model(selected_model, hf_token)
        if model:
            st.success(f"✅ Base model loaded: `{selected_model}`")
        else:
            st.error("❌ Failed to load the base model. Aborting.")
            st.stop()
    # Simulate fine-tuning loop
    progress_bar = st.progress(0)
    training_placeholder = st.empty()
    for epoch, losses, accs in simulate_training(epochs):
        fig = plot_training_metrics(epoch, losses, accs)
        training_placeholder.pyplot(fig)
        progress_bar.progress(epoch / epochs)
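    # The loop above assumes simulate_training(epochs) yields (epoch, losses, accs)
    # once per completed epoch and that plot_training_metrics returns a Matplotlib
    # figure, so the chart and progress bar can be refreshed in place each epoch.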
    # Save fine-tuned model with timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # Build a model label that works for both branches (selected_model is None
    # when re-finetuning a saved checkpoint).
    if selected_model:
        model_label = selected_model.replace('/', '_')
    else:
        model_label = os.path.splitext(os.path.basename(saved_model_path))[0]
    os.makedirs("models", exist_ok=True)  # make sure the output folder exists (save_model may also do this)
    new_model_name = f"models/fine_tuned_model_{model_label}_{timestamp}.pt"
    # Save the fine-tuned model
    saved_model_path = save_model(model, new_model_name)
    if saved_model_path:
        st.success(f"✅ Fine-tuning completed! Model saved as `{saved_model_path}`")
        # Load the fine-tuned model for immediate inference
        model = load_finetuned_model(model, saved_model_path)
        if model:
            st.success("🛠️ Fine-tuned model loaded and ready for inference!")
        else:
            st.error("❌ Failed to load the fine-tuned model for inference.")
    else:
        st.error("❌ Failed to save the fine-tuned model.")