# mohitkumarrajbadi's picture
# New Framework Change
# 2bdd84f
# raw
# history blame
# 6 kB
import hashlib
import json
import os
from datetime import datetime

import pandas as pd
import streamlit as st

from utils import (
    get_hf_token,
    load_finetuned_model,
    load_model,
    plot_training_metrics,
    save_model,
    simulate_training,
)
# Page title shown at the top of the fine-tuning screen.
st.title("πŸ”₯ Fine-tune the Gemma Model")

# -------------------------------
# Finetuning Option Selection
# -------------------------------
# The two modes this page supports; the chosen label drives the model
# selection and loading logic further down.
_finetune_choices = [
    "Fine-tune from scratch",
    "Refinetune existing model",
]
finetune_option = st.radio(
    "Select Finetuning Option",
    _finetune_choices,
)
# -------------------------------
# Model Selection Logic
# -------------------------------
# Hugging Face Gemma 3 checkpoints offered as base models. Saved checkpoints
# are also matched against this list to recover the base they came from.
model_list = [
    "google/gemma-3-1b-pt",
    "google/gemma-3-1b-it",
    "google/gemma-3-4b-pt",
    "google/gemma-3-4b-it",
    "google/gemma-3-12b-pt",
    "google/gemma-3-12b-it",
    "google/gemma-3-27b-pt",
    "google/gemma-3-27b-it",
]
model_dir = "models"

selected_model = None
saved_model_path = None

if finetune_option == "Refinetune existing model":
    # Dynamically list all saved models from the /models folder. Sorted so the
    # dropdown order is stable across reruns (os.listdir order is arbitrary).
    if os.path.isdir(model_dir):
        saved_models = sorted(f for f in os.listdir(model_dir) if f.endswith(".pt"))
    else:
        saved_models = []

    if saved_models:
        saved_model_file = st.selectbox("Select a saved model to re-finetune", saved_models)
        saved_model_path = os.path.join(model_dir, saved_model_file)
        st.success(f"βœ… Selected model for refinement: `{saved_model_path}`")
        # Checkpoints saved by this page are named
        # "fine_tuned_model_<base id with '/' replaced by '_'>_<timestamp>.pt".
        # Recover the base from that name so the checkpoint can be loaded into
        # the matching architecture; any other file keeps the previous fixed
        # default base.
        selected_model = next(
            (
                candidate
                for candidate in model_list
                if saved_model_file.startswith(
                    f"fine_tuned_model_{candidate.replace('/', '_')}_"
                )
            ),
            "google/gemma-3-1b-it",
        )
        st.caption(f"Base model for this checkpoint: `{selected_model}`")
    else:
        st.warning("⚠️ No saved models found! Switching to fine-tuning from scratch.")
        finetune_option = "Fine-tune from scratch"

# Deliberately a separate `if`, not `elif`: it also runs after the fallback
# above, so the user can actually pick a base model instead of being left
# with selected_model = None.
if finetune_option == "Fine-tune from scratch":
    # Display Hugging Face model list
    selected_model = st.selectbox("πŸ› οΈ Select Gemma Model to Fine-tune", model_list)
# -------------------------------
# Dataset Selection
# -------------------------------
st.subheader("πŸ“š Dataset Selection")

# Dataset source selection
dataset_option = st.radio("Choose dataset:", ["Upload New Dataset", "Use Existing Dataset (`train_data.csv`)"])
dataset_path = "train_data.csv"

if dataset_option == "Upload New Dataset":
    uploaded_file = st.file_uploader("πŸ“€ Upload Dataset (CSV or JSON)", type=["csv", "json"])
    if uploaded_file is not None:
        # Streamlit re-executes this script on every widget interaction and the
        # uploader keeps returning the same file. Remember (per session) which
        # uploads were already written; otherwise every rerun, including the
        # click on "Start Fine-tuning", would append the same rows again.
        if "ingested_uploads" not in st.session_state:
            st.session_state["ingested_uploads"] = set()
        ingested_uploads = st.session_state["ingested_uploads"]
        upload_key = (
            uploaded_file.name,
            hashlib.sha256(uploaded_file.getvalue()).hexdigest(),
        )

        if upload_key in ingested_uploads:
            st.info(f"`{uploaded_file.name}` has already been added to `{dataset_path}`.")
        else:
            # Compare extensions case-insensitively so e.g. "DATA.CSV" is accepted.
            file_name = uploaded_file.name.lower()
            if not file_name.endswith((".csv", ".json")):
                st.error("❌ Unsupported file format. Please upload CSV or JSON.")
                st.stop()

            # Handle CSV or JSON upload
            try:
                if file_name.endswith(".csv"):
                    new_data = pd.read_csv(uploaded_file)
                else:
                    new_data = pd.json_normalize(json.load(uploaded_file))
            except (ValueError, TypeError) as err:
                # pandas parser errors, json.JSONDecodeError and
                # UnicodeDecodeError all derive from ValueError; TypeError
                # covers JSON payloads json_normalize cannot tabulate.
                st.error(f"Could not parse `{uploaded_file.name}`: {err}")
                st.stop()

            # Append or create new dataset. An empty file has no header to
            # align with, so it is (re)written with one.
            if os.path.exists(dataset_path) and os.path.getsize(dataset_path) > 0:
                existing_columns = list(pd.read_csv(dataset_path, nrows=0).columns)
                if set(existing_columns) != set(new_data.columns):
                    st.error(
                        f"Columns of `{uploaded_file.name}` do not match `{dataset_path}` "
                        f"(expected: {', '.join(map(str, existing_columns))})."
                    )
                    st.stop()
                # Rows are written without a header, so reorder them into the
                # file's column order to keep every value under the right column.
                new_data[existing_columns].to_csv(dataset_path, mode='a', index=False, header=False)
                st.success(f"βœ… Data appended to `{dataset_path}`!")
            else:
                new_data.to_csv(dataset_path, index=False)
                st.success(f"βœ… Dataset saved as `{dataset_path}`!")
            ingested_uploads.add(upload_key)
elif dataset_option == "Use Existing Dataset (`train_data.csv`)":
    if os.path.exists(dataset_path):
        st.success("βœ… Using existing `train_data.csv` for fine-tuning.")
    else:
        st.error("❌ `train_data.csv` not found! Please upload a new dataset.")
        st.stop()
# -------------------------------
# Hyperparameters Configuration
# -------------------------------
# Lower bounds keep the values usable: at least one epoch (the training loop
# divides by the epoch count) and a positive batch size. The explicit float
# step replaces Streamlit's default of 0.01, which would jump the learning
# rate from 1e-4 straight to 0.0101.
learning_rate = st.number_input("πŸ“Š Learning Rate", value=1e-4, min_value=0.0, step=1e-5, format="%.5f")
batch_size = st.number_input("πŸ› οΈ Batch Size", value=16, min_value=1, step=1)
epochs = st.number_input("⏱️ Epochs", value=3, min_value=1, step=1)
# -------------------------------
# Fine-tuning Execution
# -------------------------------
if st.button("πŸš€ Start Fine-tuning"):
    st.info("Fine-tuning process initiated...")

    # Retrieve Hugging Face Token
    hf_token = get_hf_token()

    # Model loading logic. `base_model_id` is the Hugging Face id the weights
    # are built on; it is also encoded in the saved file name below.
    if finetune_option == "Refinetune existing model" and saved_model_path:
        # A saved checkpoint is applied on top of a freshly loaded base model.
        # Use the base chosen during model selection when one is known;
        # otherwise keep the historical default.
        base_model_id = selected_model or "google/gemma-3-1b-it"
        tokenizer, model = load_model(base_model_id, hf_token)
        # Load the saved model checkpoint for re-finetuning; a falsy result is
        # treated as a failed load.
        model = load_finetuned_model(model, saved_model_path)
        if model:
            st.success(f"βœ… Loaded saved model: `{saved_model_path}` for refinement!")
        else:
            st.error("❌ Failed to load the saved model. Aborting.")
            st.stop()
    else:
        # Fine-tune from scratch (load base model)
        if not selected_model:
            st.error("❌ Please select a model to fine-tune.")
            st.stop()
        base_model_id = selected_model
        tokenizer, model = load_model(base_model_id, hf_token)
        if model:
            st.success(f"βœ… Base model loaded: `{selected_model}`")
        else:
            st.error("❌ Failed to load the base model. Aborting.")
            st.stop()

    # Simulate fine-tuning loop. NOTE(review): only `epochs` is passed to
    # simulate_training; learning_rate and batch_size are collected above but
    # not used here.
    progress_bar = st.progress(0)
    training_placeholder = st.empty()
    # Guard the divisor, and clamp because st.progress rejects floats outside
    # [0.0, 1.0].
    total_epochs = max(int(epochs), 1)
    for epoch, losses, accs in simulate_training(epochs):
        fig = plot_training_metrics(epoch, losses, accs)
        training_placeholder.pyplot(fig)
        progress_bar.progress(min(max(epoch / total_epochs, 0.0), 1.0))

    # Save fine-tuned model with timestamp. Use base_model_id rather than
    # selected_model, which may be None on the re-finetune path.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    os.makedirs("models", exist_ok=True)
    new_model_name = f"models/fine_tuned_model_{base_model_id.replace('/', '_')}_{timestamp}.pt"

    # Save the fine-tuned model; a falsy return is treated as a failed save.
    saved_model_path = save_model(model, new_model_name)
    if saved_model_path:
        st.success(f"βœ… Fine-tuning completed! Model saved as `{saved_model_path}`")
        # Load the fine-tuned model for immediate inference
        model = load_finetuned_model(model, saved_model_path)
        if model:
            st.success("πŸ› οΈ Fine-tuned model loaded and ready for inference!")
        else:
            st.error("❌ Failed to load the fine-tuned model for inference.")
    else:
        st.error("❌ Failed to save the fine-tuned model.")