"""
Fine-tuning script for the SmolLM2-135M model using Unsloth.

This script demonstrates how to:
1. Install and configure Unsloth
2. Prepare and format training data
3. Configure and run the training process
4. Save and evaluate the model

To run this script:
1. Install dependencies: pip install -r requirements.txt
2. Run: python train.py
"""
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Union

import hydra
from omegaconf import DictConfig, OmegaConf

# Import unsloth before transformers/trl so its performance patches are applied.
from unsloth import FastLanguageModel, is_bfloat16_supported
from unsloth.chat_templates import get_chat_template

from datasets import (
    Dataset,
    DatasetDict,
    IterableDataset,
    IterableDatasetDict,
    load_dataset,
)
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from trl import SFTTrainer

def setup_logging():
    """Configure logging for the training process."""
    log_dir = Path("logs")
    log_dir.mkdir(exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = log_dir / f"training_{timestamp}.log"

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
    )

    logger = logging.getLogger(__name__)
    logger.info(f"Logging initialized. Log file: {log_file}")
    return logger


logger = setup_logging()

def install_dependencies():
    """Install required dependencies."""
    logger.info("Installing dependencies...")
    try:
        # os.system returns the shell's exit status, so failures must be checked
        # explicitly; otherwise the except branch would never be reached.
        status = os.system(
            'pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"'
        )
        status |= os.system("pip install --no-deps xformers trl peft accelerate bitsandbytes")
        if status != 0:
            raise RuntimeError("pip install exited with a non-zero status")
        logger.info("Dependencies installed successfully")
    except Exception as e:
        logger.error(f"Error installing dependencies: {e}")
        raise

def load_model(cfg: DictConfig) -> tuple[FastLanguageModel, AutoTokenizer]:
    """Load and configure the model."""
    logger.info("Loading model and tokenizer...")
    try:
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=cfg.model.name,
            max_seq_length=cfg.model.max_seq_length,
            dtype=cfg.model.dtype,
            load_in_4bit=cfg.model.load_in_4bit,
        )
        logger.info("Base model loaded successfully")

        model = FastLanguageModel.get_peft_model(
            model,
            r=cfg.peft.r,
            target_modules=cfg.peft.target_modules,
            lora_alpha=cfg.peft.lora_alpha,
            lora_dropout=cfg.peft.lora_dropout,
            bias=cfg.peft.bias,
            use_gradient_checkpointing=cfg.peft.use_gradient_checkpointing,
            random_state=cfg.peft.random_state,
            use_rslora=cfg.peft.use_rslora,
            loftq_config=cfg.peft.loftq_config,
        )
        logger.info("LoRA configuration applied successfully")

        return model, tokenizer
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        raise

def load_and_format_dataset(
    tokenizer: AutoTokenizer,
    cfg: DictConfig,
) -> tuple[
    Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset], AutoTokenizer
]:
    """Load and format the training dataset."""
    logger.info("Loading and formatting dataset...")
    try:
        dataset = load_dataset("xingyaoww/code-act", split="codeact")
        logger.info(f"Dataset loaded successfully. Size: {len(dataset)} examples")

        dataset = dataset.train_test_split(
            test_size=cfg.dataset.validation_split, seed=cfg.dataset.seed
        )
        logger.info(
            f"Dataset split into train ({len(dataset['train'])} examples) "
            f"and validation ({len(dataset['test'])} examples) sets"
        )

        # Remap ShareGPT-style keys ("from"/"value", "human"/"gpt") onto the ChatML template.
        tokenizer = get_chat_template(
            tokenizer,
            chat_template="chatml",
            mapping={
                "role": "from",
                "content": "value",
                "user": "human",
                "assistant": "gpt",
            },
            map_eos_token=True,
        )
        logger.info("Chat template configured successfully")

        def formatting_prompts_func(examples):
            convos = examples["conversations"]
            texts = [
                tokenizer.apply_chat_template(
                    convo, tokenize=False, add_generation_prompt=False
                )
                for convo in convos
            ]
            return {"text": texts}

        dataset = DatasetDict(
            {
                "train": dataset["train"].map(formatting_prompts_func, batched=True),
                "validation": dataset["test"].map(
                    formatting_prompts_func, batched=True
                ),
            }
        )
        logger.info("Dataset formatting completed successfully")

        return dataset, tokenizer
    except Exception as e:
        logger.error(f"Error loading/formatting dataset: {e}")
        raise

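
# For reference, formatting_prompts_func above renders each conversation into a
# single ChatML-formatted string in the "text" column, with each turn roughly of
# the form "<|im_start|>{role}\n{message}<|im_end|>".
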
def create_trainer(
    model: FastLanguageModel,
    tokenizer: AutoTokenizer,
    dataset: Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset],
    cfg: DictConfig,
) -> Trainer:
    """Create and configure the SFTTrainer."""
    logger.info("Creating trainer...")
    try:
        training_args_dict = OmegaConf.to_container(cfg.training.args, resolve=True)
        # Pick the mixed-precision mode based on hardware support.
        training_args_dict.update(
            {
                "fp16": not is_bfloat16_supported(),
                "bf16": is_bfloat16_supported(),
            }
        )
        training_args = TrainingArguments(**training_args_dict)

        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            **cfg.training.sft.data_collator,
        )

        # Forward the remaining SFT settings, minus the collator kwargs used above.
        sft_config = OmegaConf.to_container(cfg.training.sft, resolve=True)
        sft_config.pop("data_collator", None)

        trainer = SFTTrainer(
            model=model,
            tokenizer=tokenizer,
            train_dataset=dataset["train"],
            eval_dataset=dataset["validation"],
            args=training_args,
            data_collator=data_collator,
            **sft_config,
        )
        logger.info("Trainer created successfully")
        return trainer
    except Exception as e:
        logger.error(f"Error creating trainer: {e}")
        raise

@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    """Main training function."""
    try:
        logger.info("Starting training process...")
        logger.info(f"Configuration:\n{OmegaConf.to_yaml(cfg)}")

        install_dependencies()

        model, tokenizer = load_model(cfg)

        dataset, tokenizer = load_and_format_dataset(tokenizer, cfg)

        trainer: Trainer = create_trainer(model, tokenizer, dataset, cfg)

        if cfg.train:
            logger.info("Starting training...")
            trainer.train()

            logger.info(f"Saving final model to {cfg.output.dir}...")
            trainer.save_model(cfg.output.dir)

            final_metrics = trainer.state.log_history[-1]
            logger.info("\nTraining completed!")
            logger.info(f"Final training loss: {final_metrics.get('loss', 'N/A')}")
            logger.info(f"Final validation loss: {final_metrics.get('eval_loss', 'N/A')}")
        else:
            logger.info("Training skipped as train=False")

    except Exception as e:
        logger.error(f"Error in main training process: {e}")
        raise

if __name__ == "__main__":
    main()
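
# Any config value can be overridden from the command line via Hydra's key=value
# syntax; the keys below follow the example config sketched near the top of this file:
#   python train.py train=false
#   python train.py model.max_seq_length=1024 training.args.num_train_epochs=3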