minidalalm / src /train_minilm.py
crossroderick's picture
Upload folder using huggingface_hub
83daab2 verified
from datasets import load_dataset
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses, evaluation
# Fine-tune a multilingual MiniLM sentence encoder on text pairs and save it.
#
# Assumes every JSONL line has a "texts" list whose first two entries form a
# positive (anchor, positive) pair — TODO confirm against the step that
# produces clean_pairs.jsonl.

# Path config
base_model = "paraphrase-multilingual-MiniLM-L12-v2"
data_path = "src/data/clean_pairs.jsonl"
output_dir = "MiniDalaLM"

# Split / training config
eval_fraction = 0.25
split_seed = 42  # fixed so the train/eval split is reproducible across runs
batch_size = 32

# Load the full dataset (reuse data_path instead of repeating the literal)
dataset = load_dataset("json", data_files = data_path, split = "train")

# Split into train and eval sets (75/25). train_test_split shuffles before
# splitting, so the held-out set is not simply the tail of the file, which
# would be biased if the file is sorted or grouped.
splits = dataset.train_test_split(test_size = eval_fraction, seed = split_seed)

# Create input examples
train_samples = [InputExample(texts = entry["texts"]) for entry in splits["train"]]
eval_samples = [InputExample(texts = entry["texts"]) for entry in splits["test"]]

# Model and loss
model = SentenceTransformer(base_model)
train_dataloader = DataLoader(train_samples, shuffle = True, batch_size = batch_size)
train_loss = losses.MultipleNegativesRankingLoss(model)

# Evaluation setup.
# The examples carry no label, so they fall back to InputExample's default
# (0 in sentence-transformers). Feeding them to from_input_examples would mark
# every eval pair as a negative and make the binary-classification metrics
# meaningless. Instead:
#   - each true (anchor, positive) pair is scored as label 1;
#   - each anchor paired with another example's positive (rotated by one) is
#     scored as label 0.
# Caveats: with a single eval example the rotated "negative" is the positive
# itself, and a rotated pair may occasionally be a genuine paraphrase.
eval_anchors = [example.texts[0] for example in eval_samples]
eval_positives = [example.texts[1] for example in eval_samples]
eval_negatives = eval_positives[1:] + eval_positives[:1]
evaluator = evaluation.BinaryClassificationEvaluator(
    sentences1 = eval_anchors + eval_anchors,
    sentences2 = eval_positives + eval_negatives,
    labels = [1] * len(eval_anchors) + [0] * len(eval_anchors),
    name = "eval",
)

# Train with eval
# NOTE(review): a fractional epoch count relies on the Trainer-backed `fit` of
# sentence-transformers v3+; the legacy v2 loop iterates range(epochs) and
# would reject 0.5 — confirm the pinned library version.
model.fit(
    train_objectives = [(train_dataloader, train_loss)],
    epochs = 0.5,
    warmup_steps = 100,
    evaluator = evaluator,
    evaluation_steps = 1000,
    show_progress_bar = True
)

# Save final model
model.save(output_dir)