"""Fine-tune a multilingual MiniLM sentence encoder on text pairs.

The script loads records from a JSON-lines file, keeps the first 75% for
training and the remaining 25% for evaluation, trains with
``MultipleNegativesRankingLoss`` for half a pass over the training data and
saves the resulting model to the ``MiniDalaLM`` directory.
"""

from datasets import load_dataset
from sentence_transformers import InputExample, SentenceTransformer, evaluation, losses
from torch.utils.data import DataLoader

base_model = "paraphrase-multilingual-MiniLM-L12-v2"
data_path = "src/data/clean_pairs.jsonl"

# Use the configured path instead of repeating the literal a second time.
dataset = load_dataset("json", data_files=data_path, split="train")

# Every record must expose a "texts" field. It is presumably an
# [anchor, positive] pair -- TODO confirm against the data file.
all_samples = [InputExample(texts=entry["texts"]) for entry in dataset]

# Ordered 75/25 hold-out split (no shuffling, same as before).
split_idx = int(len(all_samples) * 0.75)
train_samples = all_samples[:split_idx]
held_out_pairs = all_samples[split_idx:]

# BinaryClassificationEvaluator reads InputExample.label, which defaults to 0.
# Passing the held-out pairs unchanged would therefore score every true pair
# as a negative. Label them 1 and add mismatched pairs (each first text with
# the second text of the next example) as label-0 negatives so that both
# classes are present.
# NOTE(review): a rotated pair can still be a true match if the data contains
# duplicates; acceptable for a monitoring metric, but verify for reporting.
eval_samples = [
    InputExample(texts=list(pair.texts[:2]), label=1) for pair in held_out_pairs
]
if len(held_out_pairs) >= 2:
    rotated_pairs = held_out_pairs[1:] + held_out_pairs[:1]
    eval_samples.extend(
        InputExample(texts=[pair.texts[0], other.texts[1]], label=0)
        for pair, other in zip(held_out_pairs, rotated_pairs)
    )
    evaluator = evaluation.BinaryClassificationEvaluator.from_input_examples(
        eval_samples, name="eval"
    )
else:
    # Too few held-out pairs to build a negative; train without evaluation.
    evaluator = None

model = SentenceTransformer(base_model)
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=32)
train_loss = losses.MultipleNegativesRankingLoss(model)

# fit() documents `epochs` as an int; a float breaks the legacy training loop.
# Half a pass over the data is expressed through `steps_per_epoch` instead.
half_epoch_steps = max(1, len(train_dataloader) // 2)

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=1,
    steps_per_epoch=half_epoch_steps,
    warmup_steps=100,
    evaluator=evaluator,
    evaluation_steps=1000,
    show_progress_bar=True,
)

model.save("MiniDalaLM")